diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
new file mode 100644
index 000000000..82624bcd2
--- /dev/null
+++ b/.github/workflows/pre-commit.yml
@@ -0,0 +1,26 @@
+name: pre-commit
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install pre-commit
+        run: |
+          python -m pip install --upgrade pip
+          pip install pre-commit
+
+      - name: Run pre-commit
+        run: pre-commit run --all-files
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..c5a33dc6a
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,17 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.14.0
+    hooks:
+      - id: ruff
+        args:
+          - --fix
+          - --exit-non-zero-on-fix
+  - repo: local
+    hooks:
+      - id: ruff-format-with-kwargs
+        name: Ruff format with kwarg spacing
+        entry: scripts/run_ruff_format.py
+        language: python
+        types: [python]
+        additional_dependencies:
+          - ruff==0.6.9
diff --git a/pyproject.toml b/pyproject.toml
index e9b4ef763..f1491a68d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -865,3 +865,25 @@ amd = [
 homepage = "http://www.unsloth.ai"
 documentation = "https://github.com/unslothai/unsloth"
 repository = "https://github.com/unslothai/unsloth"
+
+[tool.ruff]
+target-version = "py311"
+
+[tool.ruff.lint]
+select = ["E9", "F63", "F7", "F82"]
+ignore = [
+    "E402",
+    "E722",
+    "F403",
+    "F405",
+    "F811",
+    "F821",
+    "F841",
+    "F401",
+    "E731",
+    "E741",
+    "F601",
+    "E712",
+]
+
+[tool.ruff.format]
diff --git a/scripts/enforce_kwargs_spacing.py b/scripts/enforce_kwargs_spacing.py
new file mode 100755
index 000000000..ca2ff343a
--- /dev/null
+++ b/scripts/enforce_kwargs_spacing.py
@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
+"""Ensure keyword arguments use spaces around '=', prune redundant pass statements."""
+
+from __future__ import annotations
+
+import ast
+import argparse
+import io
+import sys
+import tokenize
+from collections import defaultdict
+from pathlib import Path
+
+
+def enforce_spacing(text: str) -> tuple[str, bool]:
+    """Return updated text with keyword '=' padded by spaces, plus change flag."""
+    lines = text.splitlines(keepends=True)
+    if not lines:
+        return text, False
+
+    offsets: dict[int, int] = defaultdict(int)
+    changed = False
+
+    reader = io.StringIO(text).readline
+    for token in tokenize.generate_tokens(reader):
+        if token.type != tokenize.OP or token.string != "=":
+            continue
+
+        line_index = token.start[0] - 1
+        col = token.start[1] + offsets[line_index]
+
+        if line_index < 0 or line_index >= len(lines):
+            continue
+
+        line = lines[line_index]
+        if col >= len(line) or line[col] != "=":
+            continue
+
+        line_changed = False
+
+        # Insert a space before '=' when missing and not preceded by whitespace.
+        if col > 0 and line[col - 1] not in {" ", "\t"}:
+            line = f"{line[:col]} {line[col:]}"
+            offsets[line_index] += 1
+            col += 1
+            line_changed = True
+            changed = True
+
+        # Insert a space after '=' when missing and not followed by whitespace or newline.
+        next_index = col + 1
+        if next_index < len(line) and line[next_index] not in {" ", "\t", "\n", "\r"}:
+            line = f"{line[:next_index]} {line[next_index:]}"
+            offsets[line_index] += 1
+            line_changed = True
+            changed = True
+
+        if line_changed:
+            lines[line_index] = line
+
+    if not changed:
+        return text, False
+
+    return "".join(lines), True
+
+
+def remove_redundant_passes(text: str) -> tuple[str, bool]:
+    """Drop pass statements that share a block with other executable code."""
+
+    try:
+        tree = ast.parse(text)
+    except SyntaxError:
+        return text, False
+
+    redundant: list[ast.Pass] = []
+
+    def visit(node: ast.AST) -> None:
+        for attr in ("body", "orelse", "finalbody"):
+            value = getattr(node, attr, None)
+            if not isinstance(value, list) or len(value) <= 1:
+                continue
+            for stmt in value:
+                if isinstance(stmt, ast.Pass):
+                    redundant.append(stmt)
+            for stmt in value:
+                if isinstance(stmt, ast.AST):
+                    visit(stmt)
+        handlers = getattr(node, "handlers", None)
+        if handlers:
+            for handler in handlers:
+                visit(handler)
+
+    visit(tree)
+
+    if not redundant:
+        return text, False
+
+    lines = text.splitlines(keepends=True)
+    changed = False
+
+    for node in sorted(
+        redundant, key=lambda item: (item.lineno, item.col_offset), reverse=True
+    ):
+        start = node.lineno - 1
+        end = (node.end_lineno or node.lineno) - 1
+        if start >= len(lines):
+            continue
+        changed = True
+        if start == end:
+            line = lines[start]
+            col_start = node.col_offset
+            col_end = node.end_col_offset or (col_start + 4)
+            segment = line[:col_start] + line[col_end:]
+            lines[start] = segment if segment.strip() else ""
+            continue
+
+        # Defensive fall-back for unexpected multi-line 'pass'.
+        prefix = lines[start][: node.col_offset]
+        lines[start] = prefix if prefix.strip() else ""
+        for idx in range(start + 1, end):
+            lines[idx] = ""
+        suffix = lines[end][(node.end_col_offset or 0) :]
+        lines[end] = suffix
+
+    # Normalise to ensure lines end with newlines except at EOF.
+    result_lines: list[str] = []
+    for index, line in enumerate(lines):
+        if not line:
+            continue
+        if index < len(lines) - 1 and not line.endswith("\n"):
+            result_lines.append(f"{line}\n")
+        else:
+            result_lines.append(line)
+
+    return "".join(result_lines), changed
+
+
+def process_file(path: Path) -> bool:
+    try:
+        with tokenize.open(path) as handle:
+            original = handle.read()
+            encoding = handle.encoding
+    except (OSError, SyntaxError) as exc:  # SyntaxError from tokenize on invalid python
+        print(f"Failed to read {path}: {exc}", file=sys.stderr)
+        return False
+
+    updated, changed = enforce_spacing(original)
+    updated, removed = remove_redundant_passes(updated)
+    if changed or removed:
+        path.write_text(updated, encoding=encoding)
+        return True
+    return False
+
+
+def main(argv: list[str]) -> int:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("files", nargs="+", help="Python files to fix")
+    args = parser.parse_args(argv)
+
+    touched: list[Path] = []
+    self_path = Path(__file__).resolve()
+
+    for entry in args.files:
+        path = Path(entry)
+        # Skip modifying this script to avoid self-edit loops.
+        if path.resolve() == self_path:
+            continue
+        if not path.exists() or path.is_dir():
+            continue
+        if process_file(path):
+            touched.append(path)
+
+    if touched:
+        for path in touched:
+            print(f"Adjusted kwarg spacing in {path}")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main(sys.argv[1:]))
diff --git a/scripts/run_ruff_format.py b/scripts/run_ruff_format.py
new file mode 100755
index 000000000..5ec16cd9f
--- /dev/null
+++ b/scripts/run_ruff_format.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+"""Run `ruff format` followed by kwarg spacing enforcement."""
+
+from __future__ import annotations
+
+import subprocess
+import sys
+from pathlib import Path
+
+HERE = Path(__file__).resolve().parent
+
+
+def main(argv: list[str]) -> int:
+    files = [arg for arg in argv if Path(arg).exists()]
+    if not files:
+        return 0
+
+    ruff_cmd = [sys.executable, "-m", "ruff", "format", *files]
+    ruff_proc = subprocess.run(ruff_cmd)
+    if ruff_proc.returncode != 0:
+        return ruff_proc.returncode
+
+    spacing_script = HERE / "enforce_kwargs_spacing.py"
+    spacing_cmd = [sys.executable, str(spacing_script), *files]
+    spacing_proc = subprocess.run(spacing_cmd)
+    return spacing_proc.returncode
+
+
+if __name__ == "__main__":
+    raise SystemExit(main(sys.argv[1:]))
diff --git a/tests/qlora/test_hf_qlora_train_and_merge.py b/tests/qlora/test_hf_qlora_train_and_merge.py
index 797d94018..ae975b026 100644
--- a/tests/qlora/test_hf_qlora_train_and_merge.py
+++ b/tests/qlora/test_hf_qlora_train_and_merge.py
@@ -54,37 +54,37 @@
     seed = 42
     batch_size = 5
     num_generations = 5
-    tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer])
+    tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])
     temperature = 0.8
     max_new_tokens = 20
 
-    peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear")
-    model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config)
+    peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear")
+    model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)
 
     prompt = tokenizer.apply_chat_template(
-        [USER_MESSAGE], tokenize=False, add_generation_prompt=True
+        [USER_MESSAGE], tokenize = False, add_generation_prompt = True
     )
     with header_footer_context("Test Prompt and Answer"):
         print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
 
     dataset: Dataset = create_dataset(
-        tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
+        tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
     )
     with header_footer_context("Dataset"):
         print(f"Dataset: {next(iter(dataset))}")
 
     training_args = SFTConfig(
-        output_dir=output_dir,
-        max_steps=max_steps,
-        per_device_train_batch_size=batch_size,
-        log_level="info",
-        report_to="none",
-        num_train_epochs=1,
-        logging_steps=1,
-        seed=seed,
-        bf16=dtype == torch.bfloat16,
-        fp16=dtype == torch.float16,
-        save_strategy="no",
+        output_dir = output_dir,
+        max_steps = max_steps,
+        per_device_train_batch_size = batch_size,
+        log_level = "info",
+        report_to = "none",
+        num_train_epochs = 1,
+        logging_steps = 1,
+        seed = seed,
+        bf16 = dtype == torch.bfloat16,
+        fp16 = dtype == torch.float16,
+        save_strategy = "no",
     )
 
     with header_footer_context("Train Args"):
@@ -92,7 +92,7 @@
         print(peft_config)
 
     trainer = setup_trainer(
-        model, tokenizer, dataset, training_args, peft_config=peft_config
+        model, tokenizer, dataset, training_args, peft_config = peft_config
     )
 
     with header_footer_context("Model"):
@@ -108,11 +108,11 @@
     responses = sample_responses(
         model,
         tokenizer,
-        prompt=prompt,
+        prompt = prompt,
         **generation_args,
     )
     with header_footer_context("Responses before training"):
-        check_responses(responses, answer=ANSWER, prompt=prompt)
+        check_responses(responses, answer = ANSWER, prompt = prompt)
 
     with header_footer_context("Peft Weights before training"):
         for name, stats in itertools.islice(describe_peft_weights(model), 2):
@@ -129,11 +129,11 @@
     responses = sample_responses(
         model,
         tokenizer,
-        prompt=prompt,
+        prompt = prompt,
         **generation_args,
     )
     with header_footer_context("Responses after training"):
-        check_responses(responses, answer=ANSWER, prompt=prompt)
+        check_responses(responses, answer = ANSWER, prompt = prompt)
 
     model_copy = deepcopy(model)
 
@@ -142,18 +142,18 @@
     responses = sample_responses(
         merged_model,
         tokenizer,
-        prompt=prompt,
+        prompt = prompt,
         **generation_args,
     )
     with header_footer_context("Responses after custom merging to 16bit"):
-        check_responses(responses, answer=ANSWER, prompt=prompt)
+        check_responses(responses, answer = ANSWER, prompt = prompt)
 
     merged_model_peft = model_copy.merge_and_unload()
     responses = sample_responses(
         merged_model_peft,
         tokenizer,
-        prompt=prompt,
+        prompt = prompt,
         **generation_args,
     )
     with header_footer_context("Responses after peft merge_and_unload"):
-        check_responses(responses, answer=ANSWER, prompt=prompt)
+        check_responses(responses, answer = ANSWER, prompt = prompt)
diff --git a/tests/qlora/test_unsloth_qlora_train_and_merge.py b/tests/qlora/test_unsloth_qlora_train_and_merge.py
index 59fa813fa..9040ad793 100644
--- a/tests/qlora/test_unsloth_qlora_train_and_merge.py
+++ b/tests/qlora/test_unsloth_qlora_train_and_merge.py
@@ -50,13 +50,13 @@ def get_unsloth_model_and_tokenizer(
     dtype: torch.dtype = torch.bfloat16,
 ):
     return FastLanguageModel.from_pretrained(
-        model_name=model_name,
-        max_seq_length=max_seq_length,
-        load_in_4bit=load_in_4bit,
-        fast_inference=fast_inference,
-        max_lora_rank=max_lora_rank,
-        gpu_memory_utilization=gpu_memory_utilization,
-        dtype=dtype,
+        model_name = model_name,
+        max_seq_length = max_seq_length,
+        load_in_4bit = load_in_4bit,
+        fast_inference = fast_inference,
+        max_lora_rank = max_lora_rank,
+        gpu_memory_utilization = gpu_memory_utilization,
+        dtype = dtype,
     )
 
 
@@ -69,11 +69,11 @@ def get_unsloth_peft_model(
 ):
     return FastLanguageModel.get_peft_model(
         model,
-        r=lora_rank,
-        target_modules=target_modules,
-        lora_alpha=lora_rank,
-        use_gradient_checkpointing=use_gradient_checkpointing,
-        random_state=random_state,
+        r = lora_rank,
+        target_modules = target_modules,
+        lora_alpha = lora_rank,
+        use_gradient_checkpointing = use_gradient_checkpointing,
+        random_state = random_state,
     )
 
 
@@ -101,48 +101,48 @@ def get_unsloth_peft_model(
 
     model, tokenizer = get_unsloth_model_and_tokenizer(
         model_name,
-        max_seq_length=512,
-        load_in_4bit=True,
-        fast_inference=False,
-        max_lora_rank=lora_rank,
-        dtype=dtype,
+        max_seq_length = 512,
+        load_in_4bit = True,
+        fast_inference = False,
+        max_lora_rank = lora_rank,
+        dtype = dtype,
     )
     temperature = 0.8
     max_new_tokens = 20
 
     model = get_unsloth_peft_model(
         model,
-        lora_rank=lora_rank,
-        target_modules=target_modules,
-        use_gradient_checkpointing=gradient_checkpointing,
-        random_state=seed,
+        lora_rank = lora_rank,
+        target_modules = target_modules,
+        use_gradient_checkpointing = gradient_checkpointing,
+        random_state = seed,
     )
 
     prompt = tokenizer.apply_chat_template(
-        [USER_MESSAGE], tokenize=False, add_generation_prompt=True
+        [USER_MESSAGE], tokenize = False, add_generation_prompt = True
     )
 
     with header_footer_context("Test Prompt and Answer"):
         print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
 
     dataset: Dataset = create_dataset(
-        tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
+        tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
     )
     with header_footer_context("Dataset"):
         print(f"Dataset: {next(iter(dataset))}")
 
     training_args = SFTConfig(
-        output_dir=output_dir,
-        max_steps=max_steps,
-        per_device_train_batch_size=batch_size,
-        log_level="info",
-        report_to="none",
-        num_train_epochs=1,
-        logging_steps=1,
-        seed=seed,
-        bf16=dtype == torch.bfloat16,
-        fp16=dtype == torch.float16,
-        save_strategy="no",
+        output_dir = output_dir,
+        max_steps = max_steps,
+        per_device_train_batch_size = batch_size,
+        log_level = "info",
+        report_to = "none",
+        num_train_epochs = 1,
+        logging_steps = 1,
+        seed = seed,
+        bf16 = dtype == torch.bfloat16,
+        fp16 = dtype == torch.float16,
+        save_strategy = "no",
     )
 
     with header_footer_context("Train Args"):
@@ -163,11 +163,11 @@ def get_unsloth_peft_model(
     responses = sample_responses(
         model,
         tokenizer,
-        prompt=prompt,
+        prompt = prompt,
         **generation_args,
     )
     with header_footer_context("Responses before training"):
-        check_responses(responses, answer=ANSWER, prompt=prompt)
+        check_responses(responses, answer = ANSWER, prompt = prompt)
     with header_footer_context("Peft Weights before training"):
         for name, stats in itertools.islice(describe_peft_weights(model), 2):
             print(f"{name}:\n{stats}")
@@ -183,29 +183,29 @@ def get_unsloth_peft_model(
     responses = sample_responses(
         model,
         tokenizer,
-        prompt=prompt,
+        prompt = prompt,
         **generation_args,
     )
     with header_footer_context("Responses after training"):
-        check_responses(responses, answer=ANSWER, prompt=prompt)
+        check_responses(responses, answer = ANSWER, prompt = prompt)
 
     model.save_pretrained_merged(
         unsloth_merged_path,
         tokenizer,
-        save_method="merged_16bit",
+        save_method = "merged_16bit",
     )
     merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer(
         unsloth_merged_path,
-        max_seq_length=512,
-        load_in_4bit=False,
-        fast_inference=False,
-        dtype=dtype,
+        max_seq_length = 512,
+        load_in_4bit = False,
+        fast_inference = False,
+        dtype = dtype,
     )
     responses = sample_responses(
         merged_model_unsloth,
         tokenizer,
-        prompt=prompt,
+        prompt = prompt,
         **generation_args,
     )
     with header_footer_context("Responses after unsloth merge to 16bit"):
-        check_responses(responses, answer=ANSWER, prompt=prompt)
+        check_responses(responses, answer = ANSWER, prompt = prompt)
diff --git a/tests/saving/gpt-oss-merge/test_merged_model.py b/tests/saving/gpt-oss-merge/test_merged_model.py
index c13a941fb..48f0ed2d3 100644
--- a/tests/saving/gpt-oss-merge/test_merged_model.py
+++ b/tests/saving/gpt-oss-merge/test_merged_model.py
@@ -6,6 +6,7 @@
 import os
 import shutil
 
+
 def safe_remove_directory(path):
     try:
         if os.path.exists(path) and os.path.isdir(path):
@@ -17,14 +18,14 @@ def safe_remove_directory(path):
     except Exception as e:
         print(f"Failed to remove directory {path}: {e}")
         return False
-pass
+
 
 print("š„ Loading the 16-bit merged model from disk...")
 merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-    model_name="./gpt-oss-finetuned-merged",
-    max_seq_length=1024,
-    load_in_4bit=True,
-    load_in_8bit=False,
+    model_name = "./gpt-oss-finetuned-merged",
+    max_seq_length = 1024,
+    load_in_4bit = True,
+    load_in_8bit = False,
 )
 print("ā
 Merged model loaded successfully.")
 
@@ -38,10 +39,12 @@ def safe_remove_directory(path):
     add_generation_prompt = True,
     return_tensors = "pt",
     return_dict = True,
-    reasoning_effort = "low", # **NEW!** Set reasoning effort to low, medium or high
+    reasoning_effort = "low",  # **NEW!** Set reasoning effort to low, medium or high
 ).to(merged_model.device)
 
-_ = merged_model.generate(**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer))
+_ = merged_model.generate(
+    **inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer)
+)
 print("\nā
 Inference complete.")
 
 # --- Final Cleanup ---
@@ -51,5 +54,7 @@ def safe_remove_directory(path):
 gc.collect()
 
 safe_remove_directory("./gpt-oss-finetuned-merged")
-safe_remove_directory("./unsloth_compiled_cache") # Clean up cache created by this process
+safe_remove_directory(
+    "./unsloth_compiled_cache"
+)  # Clean up cache created by this process
 print("ā
 Final cleanup complete. Exiting inference script.")
diff --git a/tests/saving/gpt-oss-merge/train_and_merge.py b/tests/saving/gpt-oss-merge/train_and_merge.py
index b242dbc58..308d19bfb 100644
--- a/tests/saving/gpt-oss-merge/train_and_merge.py
+++ b/tests/saving/gpt-oss-merge/train_and_merge.py
@@ -7,6 +7,7 @@
 import os
 import shutil
 
+
 def safe_remove_directory(path):
     try:
         if os.path.exists(path) and os.path.isdir(path):
@@ -18,36 +19,62 @@ def safe_remove_directory(path):
     except Exception as e:
         print(f"Failed to remove directory {path}: {e}")
         return False
-pass
+
 
 # This tokenizer will be used by the mapping function
 tokenizer = None
+
+
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize = False, add_generation_prompt = False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
+
 # --- Load 4-bit Model and Train ---
 print("Loading 4-bit Mxfp4 gpt-oss model for training...")
 max_seq_length = 1024
 model, tokenizer = FastLanguageModel.from_pretrained(
-    "unsloth/gpt-oss-20b", max_seq_length=max_seq_length, load_in_4bit=True
+    "unsloth/gpt-oss-20b", max_seq_length = max_seq_length, load_in_4bit = True
 )
 
-dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:50]").map(
-    formatting_prompts_func, batched=True
+dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split = "train[:50]").map(
+    formatting_prompts_func, batched = True
 )
 
 model = FastLanguageModel.get_peft_model(
-    model, r=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
-    lora_alpha=16, use_gradient_checkpointing="unsloth", random_state=3407,
+    model,
+    r = 8,
+    target_modules = [
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+    ],
+    lora_alpha = 16,
+    use_gradient_checkpointing = "unsloth",
+    random_state = 3407,
 )
 
 trainer = SFTTrainer(
-    model=model, tokenizer=tokenizer, train_dataset=dataset,
-    args=SFTConfig(
-        per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=10,
-        learning_rate=2e-4, output_dir="outputs", report_to="none",
+    model = model,
+    tokenizer = tokenizer,
+    train_dataset = dataset,
+    args = SFTConfig(
+        per_device_train_batch_size = 1,
+        gradient_accumulation_steps = 4,
+        max_steps = 10,
+        learning_rate = 2e-4,
+        output_dir = "outputs",
+        report_to = "none",
     ),
 )
 
@@ -57,7 +84,9 @@ def formatting_prompts_func(examples):
 
 # --- Merge and Save ---
 print("\nš¾ Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
-model.save_pretrained_merged(save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer)
+model.save_pretrained_merged(
+    save_directory = "./gpt-oss-finetuned-merged", tokenizer = tokenizer
+)
 print("ā
 Model merged and saved.")
 
 # --- Cleanup ---
@@ -67,5 +96,7 @@ def formatting_prompts_func(examples):
 gc.collect()
 
 safe_remove_directory("./outputs")
-safe_remove_directory("./unsloth_compiled_cache") # Clean up the cache created by this process
+safe_remove_directory(
+    "./unsloth_compiled_cache"
+)  # Clean up the cache created by this process
 print("ā
 Cleanup complete. Exiting training script.")
diff --git a/tests/saving/language_models/test_merge_4bit_validation.py b/tests/saving/language_models/test_merge_4bit_validation.py
index b135e5245..da5bc022a 100644
--- a/tests/saving/language_models/test_merge_4bit_validation.py
+++ b/tests/saving/language_models/test_merge_4bit_validation.py
@@ -12,106 +12,123 @@
 
 from tests.utils.cleanup_utils import safe_remove_directory
 
+
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize = False, add_generation_prompt = False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
-print(f"\n{'='*80}")
+
+print(f"\n{'=' * 80}")
 print("š PHASE 1: Loading Base Model and Initial Training")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 if torch.cuda.is_bf16_supported():
     compute_dtype = torch.bfloat16
-    attn_implementation = 'flash_attention_2'
+    attn_implementation = "flash_attention_2"
 else:
     compute_dtype = torch.float16
-    attn_implementation = 'sdpa'
+    attn_implementation = "sdpa"
 
 model, tokenizer = FastLanguageModel.from_pretrained(
-    model_name="unsloth/Llama-3.1-8B-Instruct",
-    max_seq_length=2048,
-    dtype=compute_dtype,
-    load_in_4bit=True,
-    load_in_8bit=False,
-    full_finetuning=False,
-    attn_implementation=attn_implementation
+    model_name = "unsloth/Llama-3.1-8B-Instruct",
+    max_seq_length = 2048,
+    dtype = compute_dtype,
+    load_in_4bit = True,
+    load_in_8bit = False,
+    full_finetuning = False,
+    attn_implementation = attn_implementation,
 )
 
 tokenizer = get_chat_template(
     tokenizer,
-    chat_template="llama-3.1",
+    chat_template = "llama-3.1",
 )
 
 # Load small dataset for quick training
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train[:100]")
-dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
+dataset_train = load_dataset(
+    "allenai/openassistant-guanaco-reformatted", split = "train[:100]"
+)
+dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
 
 print("ā
 Base model loaded successfully!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 2: First Fine-tuning")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model = FastLanguageModel.get_peft_model(
     model,
-    r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-    lora_alpha=16,
-    lora_dropout=0,
-    bias="none",
-    use_gradient_checkpointing="unsloth",
-    random_state=3407,
-    use_rslora=False,
-    loftq_config=None,
+    r = 16,
+    target_modules = [
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
+    lora_alpha = 16,
+    lora_dropout = 0,
+    bias = "none",
+    use_gradient_checkpointing = "unsloth",
+    random_state = 3407,
+    use_rslora = False,
+    loftq_config = None,
 )
 
 from unsloth import is_bfloat16_supported
 
 trainer = SFTTrainer(
-    model=model,
-    tokenizer=tokenizer,
-    train_dataset=dataset_train,
-    dataset_text_field="text",
-    max_seq_length=2048,
-    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
-    dataset_num_proc=2,
-    packing=False,
-    args=TrainingArguments(
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        warmup_ratio=0.1,
-        max_steps=10,  # Very short training for test
-        learning_rate=2e-4,
-        fp16=not is_bfloat16_supported(),
-        bf16=is_bfloat16_supported(),
-        logging_steps=5,
-        optim="adamw_8bit",
-        lr_scheduler_type="linear",
-        seed=3407,
-        output_dir="outputs",
-        report_to="none",
+    model = model,
+    tokenizer = tokenizer,
+    train_dataset = dataset_train,
+    dataset_text_field = "text",
+    max_seq_length = 2048,
+    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+    dataset_num_proc = 2,
+    packing = False,
+    args = TrainingArguments(
+        per_device_train_batch_size = 2,
+        gradient_accumulation_steps = 4,
+        warmup_ratio = 0.1,
+        max_steps = 10,  # Very short training for test
+        learning_rate = 2e-4,
+        fp16 = not is_bfloat16_supported(),
+        bf16 = is_bfloat16_supported(),
+        logging_steps = 5,
+        optim = "adamw_8bit",
+        lr_scheduler_type = "linear",
+        seed = 3407,
+        output_dir = "outputs",
+        report_to = "none",
     ),
 )
 
 trainer_stats = trainer.train()
 print("ā
 First fine-tuning completed!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 3: Save with Forced 4bit Merge")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model.save_pretrained_merged(
-    save_directory='./test_4bit_model',
-    tokenizer=tokenizer,
-    save_method="forced_merged_4bit"
+    save_directory = "./test_4bit_model",
+    tokenizer = tokenizer,
+    save_method = "forced_merged_4bit",
 )
 
 print("ā
 Model saved with forced 4bit merge!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 4: Loading 4bit Model and Second Fine-tuning")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 # Clean up first model
 del model
@@ -120,15 +137,15 @@ def formatting_prompts_func(examples):
 
 # Load the 4bit merged model
 model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
-    model_name="./test_4bit_model",
-    max_seq_length=2048,
-    load_in_4bit=True,
-    load_in_8bit=False,
+    model_name = "./test_4bit_model",
+    max_seq_length = 2048,
+    load_in_4bit = True,
+    load_in_8bit = False,
 )
 
 tokenizer_4bit = get_chat_template(
     tokenizer_4bit,
-    chat_template="llama-3.1",
+    chat_template = "llama-3.1",
 )
 
 print("ā
 4bit model loaded successfully!")
@@ -136,55 +153,63 @@ def formatting_prompts_func(examples):
 # Add LoRA adapters to the 4bit model
 model_4bit = FastLanguageModel.get_peft_model(
     model_4bit,
-    r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-    lora_alpha=16,
-    lora_dropout=0,
-    bias="none",
-    use_gradient_checkpointing="unsloth",
-    random_state=3407,
-    use_rslora=False,
-    loftq_config=None,
+    r = 16,
+    target_modules = [
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
+    lora_alpha = 16,
+    lora_dropout = 0,
+    bias = "none",
+    use_gradient_checkpointing = "unsloth",
+    random_state = 3407,
+    use_rslora = False,
+    loftq_config = None,
 )
 
 # Second fine-tuning
 trainer_4bit = SFTTrainer(
-    model=model_4bit,
-    tokenizer=tokenizer_4bit,
-    train_dataset=dataset_train,
-    dataset_text_field="text",
-    max_seq_length=2048,
-    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_4bit),
-    dataset_num_proc=2,
-    packing=False,
-    args=TrainingArguments(
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        warmup_ratio=0.1,
-        max_steps=10,  # Very short training for test
-        learning_rate=2e-4,
-        fp16=not is_bfloat16_supported(),
-        bf16=is_bfloat16_supported(),
-        logging_steps=5,
-        optim="adamw_8bit",
-        lr_scheduler_type="linear",
-        seed=3407,
-        output_dir="outputs_4bit",
-        report_to="none",
+    model = model_4bit,
+    tokenizer = tokenizer_4bit,
+    train_dataset = dataset_train,
+    dataset_text_field = "text",
+    max_seq_length = 2048,
+    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer_4bit),
+    dataset_num_proc = 2,
+    packing = False,
+    args = TrainingArguments(
+        per_device_train_batch_size = 2,
+        gradient_accumulation_steps = 4,
+        warmup_ratio = 0.1,
+        max_steps = 10,  # Very short training for test
+        learning_rate = 2e-4,
+        fp16 = not is_bfloat16_supported(),
+        bf16 = is_bfloat16_supported(),
+        logging_steps = 5,
+        optim = "adamw_8bit",
+        lr_scheduler_type = "linear",
+        seed = 3407,
+        output_dir = "outputs_4bit",
+        report_to = "none",
     ),
 )
 
 trainer_4bit.train()
 print("ā
 Second fine-tuning on 4bit model completed!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 try:
     model_4bit.save_pretrained_merged(
-        save_directory='./test_should_fail',
-        tokenizer=tokenizer_4bit
+        save_directory = "./test_should_fail",
+        tokenizer = tokenizer_4bit,
         # No save_method specified, should default to regular merge
     )
     assert False, "Expected TypeError but merge succeeded!"
@@ -194,23 +219,23 @@ def formatting_prompts_func(examples):
     print("ā
 Correct TypeError raised for 4bit base model regular merge attempt!")
     print(f"Error message: {str(e)}")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 6: Successful Save with Forced 4bit Method")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 try:
     model_4bit.save_pretrained_merged(
-        save_directory='./test_4bit_second',
-        tokenizer=tokenizer_4bit,
-        save_method="forced_merged_4bit"
+        save_directory = "./test_4bit_second",
+        tokenizer = tokenizer_4bit,
+        save_method = "forced_merged_4bit",
     )
     print("ā
 Successfully saved 4bit model with forced 4bit method!")
 except Exception as e:
     assert False, f"Phase 6 failed unexpectedly: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š CLEANUP")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 # Cleanup
 safe_remove_directory("./outputs")
diff --git a/tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py b/tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py
index 2d4ec8356..dd0e8c25c 100644
--- a/tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py
+++ b/tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py
@@ -1,7 +1,11 @@
 from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
 from unsloth.chat_templates import get_chat_template
 from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, TrainingArguments
+from transformers import (
+    DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
+    TrainingArguments,
+)
 from datasets import load_dataset, Dataset
 import torch
 from tqdm import tqdm
@@ -20,15 +24,26 @@
 
 
 from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import ppl_model, add_to_comparison, print_model_comparison
+from tests.utils.perplexity_eval import (
+    ppl_model,
+    add_to_comparison,
+    print_model_comparison,
+)
+
 
 # Define helper functions outside of main
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize = False, add_generation_prompt = False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
+
+def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
     """Load model and compute perplexity in subprocess"""
     from unsloth import FastLanguageModel
     from unsloth.chat_templates import get_chat_template
@@ -36,36 +51,42 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
 
     # Load model
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_llama_text_model",
-        max_seq_length=2048,
-        load_in_4bit=load_in_4bit,
-        load_in_8bit=load_in_8bit,
+        model_name = "./unsloth_out/merged_llama_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = load_in_4bit,
+        load_in_8bit = load_in_8bit,
     )
     # Set up tokenizer
     merged_tokenizer = get_chat_template(
         merged_tokenizer,
-        chat_template="llama-3.1",
+        chat_template = "llama-3.1",
     )
 
     # Load dataset fresh in subprocess
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
     # Format the dataset
     def formatting_prompts_func(examples):
         convos = examples["messages"]
-        texts = [merged_tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+        texts = [
+            merged_tokenizer.apply_chat_template(
+                convo, tokenize = False, add_generation_prompt = False
+            )
+            for convo in convos
+        ]
         return {"text": texts}
 
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     # Compute perplexity using the passed dataset
     ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
 
-
     # IMPORTANT: Convert to Python float if it's a tensor
     if torch.is_tensor(ppl_value):
         ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar
-    elif hasattr(ppl_value, 'item'):
+    elif hasattr(ppl_value, "item"):
         ppl_value = ppl_value.item()  # Convert numpy or other array types
     else:
         ppl_value = float(ppl_value)  # Ensure it's a float
@@ -80,87 +101,102 @@ def formatting_prompts_func(examples):
     torch.cuda.empty_cache()
     gc.collect()
 
+
 # Main execution code should be wrapped in this guard
 if __name__ == "__main__":
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method("spawn", force = True)
 
     if torch.cuda.is_bf16_supported():
         compute_dtype = torch.bfloat16
-        attn_implementation = 'flash_attention_2'
+        attn_implementation = "flash_attention_2"
     else:
         compute_dtype = torch.float16
-        attn_implementation = 'sdpa'
+        attn_implementation = "sdpa"
 
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name="unsloth/Llama-3.2-3B-Instruct",
-        max_seq_length=2048,
-        dtype=compute_dtype,
-        load_in_4bit=True,
-        load_in_8bit=False,
-        full_finetuning=False,
-        attn_implementation=attn_implementation
+        model_name = "unsloth/Llama-3.2-3B-Instruct",
+        max_seq_length = 2048,
+        dtype = compute_dtype,
+        load_in_4bit = True,
+        load_in_8bit = False,
+        full_finetuning = False,
+        attn_implementation = attn_implementation,
     )
 
     tokenizer = get_chat_template(
         tokenizer,
-        chat_template="llama-3.1",
+        chat_template = "llama-3.1",
     )
 
     from unsloth.chat_templates import standardize_sharegpt
-    dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
 
-    dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_train = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "train"
+    )
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
+
+    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
 
     model = FastLanguageModel.get_peft_model(
         model,
-        r=16,
-        target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-        lora_alpha=16,
-        lora_dropout=0,
-        bias="none",
-        use_gradient_checkpointing="unsloth",
-        random_state=3407,
-        use_rslora=False,
-        loftq_config=None,
+        r = 16,
+        target_modules = [
+            "k_proj",
+            "q_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "down_proj",
+            "up_proj",
+        ],
+        lora_alpha = 16,
+        lora_dropout = 0,
+        bias = "none",
+        use_gradient_checkpointing = "unsloth",
+        random_state = 3407,
+        use_rslora = False,
+        loftq_config = None,
     )
 
     from unsloth import is_bfloat16_supported
 
     trainer = SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        train_dataset=dataset_train,
-        dataset_text_field="text",
-        max_seq_length=2048,
-        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
-        dataset_num_proc=2,
-        packing=False,
-        args=TrainingArguments(
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            warmup_ratio=0.1,
-            max_steps=10,
-            learning_rate=2e-4,
-            fp16=not is_bfloat16_supported(),
-            bf16=is_bfloat16_supported(),
-            logging_steps=50,
-            optim="adamw_8bit",
-            lr_scheduler_type="linear",
-            seed=3407,
-            output_dir="outputs",
-            report_to="none",
+        model = model,
+        tokenizer = tokenizer,
+        train_dataset = dataset_train,
+        dataset_text_field = "text",
+        max_seq_length = 2048,
+        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+        dataset_num_proc = 2,
+        packing = False,
+        args = TrainingArguments(
+            per_device_train_batch_size = 2,
+            gradient_accumulation_steps = 4,
+            warmup_ratio = 0.1,
+            max_steps = 10,
+            learning_rate = 2e-4,
+            fp16 = not is_bfloat16_supported(),
+            bf16 = is_bfloat16_supported(),
+            logging_steps = 50,
+            optim = "adamw_8bit",
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = "outputs",
+            report_to = "none",
         ),
     )
 
     from unsloth.chat_templates import train_on_responses_only
+
     trainer = train_on_responses_only(
         trainer,
-        instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
-        response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+        response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
     )
 
     # run training
@@ -171,8 +207,7 @@ def formatting_prompts_func(examples):
     # saving and merging the model to local disk
     print("merge and save to local disk")
     model.save_pretrained_merged(
-        save_directory='./unsloth_out/merged_llama_text_model',
-        tokenizer=tokenizer
+        save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
     )
 
     # print("cleaning")
@@ -184,18 +219,19 @@ def formatting_prompts_func(examples):
     # load model from local disk and test
     print("Loading merged model in 4 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_llama_text_model",
-        max_seq_length=2048,
-        load_in_4bit=True,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_llama_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = True,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
-
+    add_to_comparison(
+        "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
+    )
 
     print("Computing 8-bit model perplexity in subprocess...")
     result_queue = mp.Queue()
-    p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
+    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
     p.start()
     p.join()
 
@@ -204,13 +240,16 @@ def formatting_prompts_func(examples):
 
     print("Loading merged model in 16 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_llama_text_model",
-        max_seq_length=2048,
-        load_in_4bit=False,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_llama_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = False,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model loaded 16bits", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
+    add_to_comparison(
+        "merged model loaded 16bits",
+        ppl_model(merged_model, merged_tokenizer, dataset_ppl),
+    )
 
     print_model_comparison()
 
diff --git a/tests/saving/language_models/test_merge_model_perplexity_mistral.py b/tests/saving/language_models/test_merge_model_perplexity_mistral.py
index d1942ea7d..14e657c68 100644
--- a/tests/saving/language_models/test_merge_model_perplexity_mistral.py
+++ b/tests/saving/language_models/test_merge_model_perplexity_mistral.py
@@ -1,7 +1,11 @@
 from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
 from unsloth.chat_templates import get_chat_template
 from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, TrainingArguments
+from transformers import (
+    DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
+    TrainingArguments,
+)
 from datasets import load_dataset, Dataset
 import torch
 from tqdm import tqdm
@@ -19,24 +23,24 @@
 sys.path.insert(0, str(REPO_ROOT))
 
 from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import ppl_model, add_to_comparison, print_model_comparison
+from tests.utils.perplexity_eval import (
+    ppl_model,
+    add_to_comparison,
+    print_model_comparison,
+)
 
 
-
-
-
-
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
+def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
     """Load model and compute perplexity in subprocess"""
     from unsloth import FastLanguageModel
     from tests.utils.perplexity_eval import ppl_model
 
     # Load model
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_mistral_text_model",
-        max_seq_length=2048,
-        load_in_4bit=load_in_4bit,
-        load_in_8bit=load_in_8bit,
+        model_name = "./unsloth_out/merged_mistral_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = load_in_4bit,
+        load_in_8bit = load_in_8bit,
     )
     # Set up tokenizer
     # merged_tokenizer = get_chat_template(
@@ -45,7 +49,9 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
     # )
 
     # Load dataset fresh in subprocess
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
     alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
@@ -84,28 +90,28 @@ def formatting_prompts_func(examples):
             outputs.append(assistant_message)
 
             # Create formatted text
-            text = alpaca_prompt.format(instruction, user_message, assistant_message) + EOS_TOKEN
+            text = (
+                alpaca_prompt.format(instruction, user_message, assistant_message)
+                + EOS_TOKEN
+            )
             texts.append(text)
 
         return {
             "instruction": instructions,
             "input": inputs,
             "output": outputs,
-            "text": texts
+            "text": texts,
         }
 
-
-
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     # Compute perplexity using the passed dataset
     ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
 
-
     # IMPORTANT: Convert to Python float if it's a tensor
     if torch.is_tensor(ppl_value):
         ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar
-    elif hasattr(ppl_value, 'item'):
+    elif hasattr(ppl_value, "item"):
         ppl_value = ppl_value.item()  # Convert numpy or other array types
     else:
         ppl_value = float(ppl_value)  # Ensure it's a float
@@ -120,31 +126,30 @@ def formatting_prompts_func(examples):
     torch.cuda.empty_cache()
     gc.collect()
 
+
 # Main execution code should be wrapped in this guard
 if __name__ == "__main__":
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method("spawn", force = True)
 
     if torch.cuda.is_bf16_supported():
         compute_dtype = torch.bfloat16
-        attn_implementation = 'flash_attention_2'
+        attn_implementation = "flash_attention_2"
     else:
         compute_dtype = torch.float16
-        attn_implementation = 'sdpa'
+        attn_implementation = "sdpa"
 
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name="unsloth/mistral-7b-v0.3",
-        max_seq_length=2048,
-        dtype=compute_dtype,
-        load_in_4bit=True,
-        load_in_8bit=False,
-        full_finetuning=False,
-        attn_implementation=attn_implementation
+        model_name = "unsloth/mistral-7b-v0.3",
+        max_seq_length = 2048,
+        dtype = compute_dtype,
+        load_in_4bit = True,
+        load_in_8bit = False,
+        full_finetuning = False,
+        attn_implementation = attn_implementation,
     )
 
-
     EOS_TOKEN = tokenizer.eos_token
 
-
     alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
     ### Instruction:
@@ -156,7 +161,6 @@ def formatting_prompts_func(examples):
     ### Response:
     {}"""
 
-
     # Define helper functions outside of main
     def formatting_prompts_func(examples):
         instructions = []
@@ -182,64 +186,76 @@ def formatting_prompts_func(examples):
             outputs.append(assistant_message)
 
             # Create formatted text
-            text = alpaca_prompt.format(instruction, user_message, assistant_message) + EOS_TOKEN
+            text = (
+                alpaca_prompt.format(instruction, user_message, assistant_message)
+                + EOS_TOKEN
+            )
             texts.append(text)
 
-
         return {
             "instruction": instructions,
             "input": inputs,
             "output": outputs,
-            "text": texts
+            "text": texts,
         }
 
+    dataset_train = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "train"
+    )
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
-
-    dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
-
-    dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
 
     model = FastLanguageModel.get_peft_model(
         model,
-        r=16,
-        target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-        lora_alpha=16,
-        lora_dropout=0,
-        bias="none",
-        use_gradient_checkpointing="unsloth",
-        random_state=3407,
-        use_rslora=False,
-        loftq_config=None,
+        r = 16,
+        target_modules = [
+            "k_proj",
+            "q_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "down_proj",
+            "up_proj",
+        ],
+        lora_alpha = 16,
+        lora_dropout = 0,
+        bias = "none",
+        use_gradient_checkpointing = "unsloth",
+        random_state = 3407,
+        use_rslora = False,
+        loftq_config = None,
     )
 
     from unsloth import is_bfloat16_supported
 
     trainer = SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        train_dataset=dataset_train,
-        dataset_text_field="text",
-        max_seq_length=2048,
-        dataset_num_proc=2,
-        packing=False,
-        args=TrainingArguments(
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            warmup_ratio=0.1,
-            max_steps=200,
-            learning_rate=2e-4,
-            fp16=not is_bfloat16_supported(),
-            bf16=is_bfloat16_supported(),
-            logging_steps=50,
-            optim="adamw_8bit",
-            lr_scheduler_type="linear",
-            seed=3407,
-            output_dir="outputs",
-            report_to="none",
+        model = model,
+        tokenizer = tokenizer,
+        train_dataset = dataset_train,
+        dataset_text_field = "text",
+        max_seq_length = 2048,
+        dataset_num_proc = 2,
+        packing = False,
+        args = TrainingArguments(
+            per_device_train_batch_size = 2,
+            gradient_accumulation_steps = 4,
+            warmup_ratio = 0.1,
+            max_steps = 200,
+            learning_rate = 2e-4,
+            fp16 = not is_bfloat16_supported(),
+            bf16 = is_bfloat16_supported(),
+            logging_steps = 50,
+            optim = "adamw_8bit",
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = "outputs",
+            report_to = "none",
         ),
     )
 
@@ -251,8 +267,7 @@ def formatting_prompts_func(examples):
     # saving and merging the model to local disk
     print("merge and save to local disk")
     model.save_pretrained_merged(
-        save_directory='./unsloth_out/merged_mistral_text_model',
-        tokenizer=tokenizer
+        save_directory = "./unsloth_out/merged_mistral_text_model", tokenizer = tokenizer
     )
 
     # print("cleaning")
@@ -264,18 +279,19 @@ def formatting_prompts_func(examples):
     # load model from local disk and test
     print("Loading merged model in 4 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_mistral_text_model",
-        max_seq_length=2048,
-        load_in_4bit=True,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_mistral_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = True,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
-
+    add_to_comparison(
+        "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
+    )
 
     print("Computing 8-bit model perplexity in subprocess...")
     result_queue = mp.Queue()
-    p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
+    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
     p.start()
     p.join()
 
@@ -284,13 +300,16 @@ def formatting_prompts_func(examples):
 
     print("Loading merged model in 16 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_mistral_text_model",
-        max_seq_length=2048,
-        load_in_4bit=False,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_mistral_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = False,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model loaded 16bits", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
+    add_to_comparison(
+        "merged model loaded 16bits",
+        ppl_model(merged_model, merged_tokenizer, dataset_ppl),
+    )
 
     print_model_comparison()
 
diff --git a/tests/saving/language_models/test_merge_model_perplexity_phi_4.py b/tests/saving/language_models/test_merge_model_perplexity_phi_4.py
index c0bd7faaf..bebea8168 100644
--- a/tests/saving/language_models/test_merge_model_perplexity_phi_4.py
+++ b/tests/saving/language_models/test_merge_model_perplexity_phi_4.py
@@ -1,7 +1,11 @@
 from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
 from unsloth.chat_templates import get_chat_template
 from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, TrainingArguments
+from transformers import (
+    DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
+    TrainingArguments,
+)
 from datasets import load_dataset, Dataset
 import torch
 from tqdm import tqdm
@@ -20,7 +24,12 @@
 
 
 from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import ppl_model, add_to_comparison, print_model_comparison
+from tests.utils.perplexity_eval import (
+    ppl_model,
+    add_to_comparison,
+    print_model_comparison,
+)
+
 
 # Define helper functions outside of main
 def formatting_prompts_func(examples):
@@ -31,9 +40,12 @@ def formatting_prompts_func(examples):
         )
         for convo in convos
     ]
-    return { "text" : texts, }
+    return {
+        "text": texts,
+    }
+
 
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
+def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
     """Load model and compute perplexity in subprocess"""
     from unsloth import FastLanguageModel
     from unsloth.chat_templates import get_chat_template
@@ -41,36 +53,42 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
 
     # Load model
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_phi4_text_model",
-        max_seq_length=2048,
-        load_in_4bit=load_in_4bit,
-        load_in_8bit=load_in_8bit,
+        model_name = "./unsloth_out/merged_phi4_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = load_in_4bit,
+        load_in_8bit = load_in_8bit,
     )
     # Set up tokenizer
     merged_tokenizer = get_chat_template(
         merged_tokenizer,
-        chat_template="phi-4",
+        chat_template = "phi-4",
     )
 
     # Load dataset fresh in subprocess
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
     # Format the dataset
     def formatting_prompts_func(examples):
         convos = examples["messages"]
-        texts = [merged_tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+        texts = [
+            merged_tokenizer.apply_chat_template(
+                convo, tokenize = False, add_generation_prompt = False
+            )
+            for convo in convos
+        ]
         return {"text": texts}
 
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     # Compute perplexity using the passed dataset
     ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
 
-
     # IMPORTANT: Convert to Python float if it's a tensor
     if torch.is_tensor(ppl_value):
         ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar
-    elif hasattr(ppl_value, 'item'):
+    elif hasattr(ppl_value, "item"):
         ppl_value = ppl_value.item()  # Convert numpy or other array types
     else:
         ppl_value = float(ppl_value)  # Ensure it's a float
@@ -85,86 +103,100 @@ def formatting_prompts_func(examples):
     torch.cuda.empty_cache()
     gc.collect()
 
+
 # Main execution code should be wrapped in this guard
 if __name__ == "__main__":
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method("spawn", force = True)
 
     if torch.cuda.is_bf16_supported():
         compute_dtype = torch.bfloat16
-        attn_implementation = 'flash_attention_2'
+        attn_implementation = "flash_attention_2"
     else:
         compute_dtype = torch.float16
-        attn_implementation = 'sdpa'
+        attn_implementation = "sdpa"
 
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name="unsloth/Phi-4",
-        max_seq_length=2048,
-        dtype=compute_dtype,
-        load_in_4bit=True,
-        load_in_8bit=False,
-        full_finetuning=False,
-        attn_implementation=attn_implementation
+        model_name = "unsloth/Phi-4",
+        max_seq_length = 2048,
+        dtype = compute_dtype,
+        load_in_4bit = True,
+        load_in_8bit = False,
+        full_finetuning = False,
+        attn_implementation = attn_implementation,
     )
 
     tokenizer = get_chat_template(
         tokenizer,
-        chat_template="phi-4",
+        chat_template = "phi-4",
     )
 
-    dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+    dataset_train = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "train"
+    )
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
-    dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
 
     model = FastLanguageModel.get_peft_model(
         model,
-        r=16,
-        target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-        lora_alpha=16,
-        lora_dropout=0,
-        bias="none",
-        use_gradient_checkpointing="unsloth",
-        random_state=3407,
-        use_rslora=False,
-        loftq_config=None,
+        r = 16,
+        target_modules = [
+            "k_proj",
+            "q_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "down_proj",
+            "up_proj",
+        ],
+        lora_alpha = 16,
+        lora_dropout = 0,
+        bias = "none",
+        use_gradient_checkpointing = "unsloth",
+        random_state = 3407,
+        use_rslora = False,
+        loftq_config = None,
     )
 
     from unsloth import is_bfloat16_supported
 
     trainer = SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        train_dataset=dataset_train,
-        dataset_text_field="text",
-        max_seq_length=2048,
-        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
-        dataset_num_proc=2,
-        packing=False,
-        args=TrainingArguments(
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            warmup_ratio=0.1,
-            max_steps=200,
-            learning_rate=2e-4,
-            fp16=not is_bfloat16_supported(),
-            bf16=is_bfloat16_supported(),
-            logging_steps=50,
-            optim="adamw_8bit",
-            lr_scheduler_type="linear",
-            seed=3407,
-            output_dir="outputs",
-            report_to="none",
+        model = model,
+        tokenizer = tokenizer,
+        train_dataset = dataset_train,
+        dataset_text_field = "text",
+        max_seq_length = 2048,
+        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+        dataset_num_proc = 2,
+        packing = False,
+        args = TrainingArguments(
+            per_device_train_batch_size = 2,
+            gradient_accumulation_steps = 4,
+            warmup_ratio = 0.1,
+            max_steps = 200,
+            learning_rate = 2e-4,
+            fp16 = not is_bfloat16_supported(),
+            bf16 = is_bfloat16_supported(),
+            logging_steps = 50,
+            optim = "adamw_8bit",
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = "outputs",
+            report_to = "none",
         ),
     )
 
     from unsloth.chat_templates import train_on_responses_only
+
     trainer = train_on_responses_only(
         trainer,
-        instruction_part="<|im_start|>user<|im_sep|>\n\n",
-        response_part="<|im_start|>assistant<|im_sep|>\n\n",
+        instruction_part = "<|im_start|>user<|im_sep|>\n\n",
+        response_part = "<|im_start|>assistant<|im_sep|>\n\n",
     )
 
     # run training
@@ -175,8 +207,7 @@ def formatting_prompts_func(examples):
     # saving and merging the model to local disk
     print("merge and save to local disk")
     model.save_pretrained_merged(
-        save_directory='./unsloth_out/merged_phi4_text_model',
-        tokenizer=tokenizer
+        save_directory = "./unsloth_out/merged_phi4_text_model", tokenizer = tokenizer
     )
 
     # print("cleaning")
@@ -188,18 +219,19 @@ def formatting_prompts_func(examples):
     # load model from local disk and test
     print("Loading merged model in 4 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_phi4_text_model",
-        max_seq_length=2048,
-        load_in_4bit=True,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_phi4_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = True,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
-
+    add_to_comparison(
+        "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
+    )
 
     print("Computing 8-bit model perplexity in subprocess...")
     result_queue = mp.Queue()
-    p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
+    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
     p.start()
     p.join()
 
@@ -208,13 +240,16 @@ def formatting_prompts_func(examples):
 
     print("Loading merged model in 16 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_phi4_text_model",
-        max_seq_length=2048,
-        load_in_4bit=False,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_phi4_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = False,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model loaded 16bits", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
+    add_to_comparison(
+        "merged model loaded 16bits",
+        ppl_model(merged_model, merged_tokenizer, dataset_ppl),
+    )
 
     print_model_comparison()
 
diff --git a/tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py b/tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py
index d26771bf8..c6da9e2ca 100644
--- a/tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py
+++ b/tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py
@@ -1,7 +1,11 @@
 from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
 from unsloth.chat_templates import get_chat_template
 from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, TrainingArguments
+from transformers import (
+    DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
+    TrainingArguments,
+)
 from datasets import load_dataset, Dataset
 import torch
 from tqdm import tqdm
@@ -19,15 +23,26 @@
 sys.path.insert(0, str(REPO_ROOT))
 
 from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import ppl_model, add_to_comparison, print_model_comparison
+from tests.utils.perplexity_eval import (
+    ppl_model,
+    add_to_comparison,
+    print_model_comparison,
+)
+
 
 # Define helper functions outside of main
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize = False, add_generation_prompt = False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
+
+def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
     """Load model and compute perplexity in subprocess"""
     from unsloth import FastLanguageModel
     from unsloth.chat_templates import get_chat_template
@@ -35,36 +50,42 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
 
     # Load model
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_llama_text_model",
-        max_seq_length=2048,
-        load_in_4bit=load_in_4bit,
-        load_in_8bit=load_in_8bit,
+        model_name = "./unsloth_out/merged_llama_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = load_in_4bit,
+        load_in_8bit = load_in_8bit,
     )
     # Set up tokenizer
     merged_tokenizer = get_chat_template(
         merged_tokenizer,
-        chat_template="llama-3.1",
+        chat_template = "llama-3.1",
     )
 
     # Load dataset fresh in subprocess
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
     # Format the dataset
     def formatting_prompts_func(examples):
         convos = examples["messages"]
-        texts = [merged_tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+        texts = [
+            merged_tokenizer.apply_chat_template(
+                convo, tokenize = False, add_generation_prompt = False
+            )
+            for convo in convos
+        ]
         return {"text": texts}
 
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     # Compute perplexity using the passed dataset
     ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
 
-
     # IMPORTANT: Convert to Python float if it's a tensor
     if torch.is_tensor(ppl_value):
         ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar
-    elif hasattr(ppl_value, 'item'):
+    elif hasattr(ppl_value, "item"):
         ppl_value = ppl_value.item()  # Convert numpy or other array types
     else:
         ppl_value = float(ppl_value)  # Ensure it's a float
@@ -79,39 +100,44 @@ def formatting_prompts_func(examples):
     torch.cuda.empty_cache()
     gc.collect()
 
+
 # Main execution code should be wrapped in this guard
 if __name__ == "__main__":
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method("spawn", force = True)
 
     if torch.cuda.is_bf16_supported():
         compute_dtype = torch.bfloat16
-        attn_implementation = 'flash_attention_2'
+        attn_implementation = "flash_attention_2"
     else:
         compute_dtype = torch.float16
-        attn_implementation = 'sdpa'
+        attn_implementation = "sdpa"
 
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name="unsloth/Llama-3.1-8B-Instruct",
-        max_seq_length=2048,
-        dtype=compute_dtype,
-        load_in_4bit=True,
-        load_in_8bit=False,
-        full_finetuning=False,
-        attn_implementation=attn_implementation
+        model_name = "unsloth/Llama-3.1-8B-Instruct",
+        max_seq_length = 2048,
+        dtype = compute_dtype,
+        load_in_4bit = True,
+        load_in_8bit = False,
+        full_finetuning = False,
+        attn_implementation = attn_implementation,
     )
 
     tokenizer = get_chat_template(
         tokenizer,
-        chat_template="llama-3.1",
+        chat_template = "llama-3.1",
     )
 
     from unsloth.chat_templates import standardize_sharegpt
-    dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
 
-    dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_train = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "train"
+    )
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
+    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     print("\n dataset sample [0]")
     print(dataset_train[0])
@@ -120,50 +146,59 @@ def formatting_prompts_func(examples):
 
     model = FastLanguageModel.get_peft_model(
         model,
-        r=16,
-        target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-        lora_alpha=16,
-        lora_dropout=0,
-        bias="none",
-        use_gradient_checkpointing="unsloth",
-        random_state=3407,
-        use_rslora=False,
-        loftq_config=None,
+        r = 16,
+        target_modules = [
+            "k_proj",
+            "q_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "down_proj",
+            "up_proj",
+        ],
+        lora_alpha = 16,
+        lora_dropout = 0,
+        bias = "none",
+        use_gradient_checkpointing = "unsloth",
+        random_state = 3407,
+        use_rslora = False,
+        loftq_config = None,
     )
 
     from unsloth import is_bfloat16_supported
 
     trainer = SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        train_dataset=dataset_train,
-        dataset_text_field="text",
-        max_seq_length=2048,
-        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
-        dataset_num_proc=2,
-        packing=False,
-        args=TrainingArguments(
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            warmup_ratio=0.1,
-            max_steps=200,
-            learning_rate=2e-4,
-            fp16=not is_bfloat16_supported(),
-            bf16=is_bfloat16_supported(),
-            logging_steps=50,
-            optim="adamw_8bit",
-            lr_scheduler_type="linear",
-            seed=3407,
-            output_dir="outputs",
-            report_to="none",
+        model = model,
+        tokenizer = tokenizer,
+        train_dataset = dataset_train,
+        dataset_text_field = "text",
+        max_seq_length = 2048,
+        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+        dataset_num_proc = 2,
+        packing = False,
+        args = TrainingArguments(
+            per_device_train_batch_size = 2,
+            gradient_accumulation_steps = 4,
+            warmup_ratio = 0.1,
+            max_steps = 200,
+            learning_rate = 2e-4,
+            fp16 = not is_bfloat16_supported(),
+            bf16 = is_bfloat16_supported(),
+            logging_steps = 50,
+            optim = "adamw_8bit",
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = "outputs",
+            report_to = "none",
         ),
     )
 
     from unsloth.chat_templates import train_on_responses_only
+
     trainer = train_on_responses_only(
         trainer,
-        instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
-        response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+        response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
     )
 
     tokenizer.decode(trainer.train_dataset[0]["input_ids"])
@@ -176,8 +211,7 @@ def formatting_prompts_func(examples):
     # saving and merging the model to local disk
     print("merge and save to local disk")
     model.save_pretrained_merged(
-        save_directory='./unsloth_out/merged_llama_text_model',
-        tokenizer=tokenizer
+        save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
     )
 
     # print("cleaning")
@@ -189,18 +223,19 @@ def formatting_prompts_func(examples):
     # load model from local disk and test
     print("Loading merged model in 4 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_llama_text_model",
-        max_seq_length=2048,
-        load_in_4bit=True,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_llama_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = True,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
-
+    add_to_comparison(
+        "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
+    )
 
     print("Computing 8-bit model perplexity in subprocess...")
     result_queue = mp.Queue()
-    p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
+    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
     p.start()
     p.join()
 
@@ -209,13 +244,16 @@ def formatting_prompts_func(examples):
 
     print("Loading merged model in 16 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_llama_text_model",
-        max_seq_length=2048,
-        load_in_4bit=False,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_llama_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = False,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model loaded 16bits", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
+    add_to_comparison(
+        "merged model loaded 16bits",
+        ppl_model(merged_model, merged_tokenizer, dataset_ppl),
+    )
 
     print_model_comparison()
 
diff --git a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py b/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
index b80197b18..d63bb9fe0 100644
--- a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
+++ b/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
@@ -1,7 +1,11 @@
 from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
 from unsloth.chat_templates import get_chat_template
 from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, TrainingArguments
+from transformers import (
+    DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
+    TrainingArguments,
+)
 from datasets import load_dataset, Dataset
 import torch
 from tqdm import tqdm
@@ -19,7 +23,11 @@
 sys.path.insert(0, str(REPO_ROOT))
 
 from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import ppl_model, add_to_comparison, print_model_comparison
+from tests.utils.perplexity_eval import (
+    ppl_model,
+    add_to_comparison,
+    print_model_comparison,
+)
 
 
 alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -33,6 +41,7 @@
 ### Response:
 {}"""
 
+
 # Define helper functions outside of main
 def formatting_prompts_func(examples):
     instructions = []
@@ -65,21 +74,21 @@ def formatting_prompts_func(examples):
         "instruction": instructions,
         "input": inputs,
         "output": outputs,
-        "text": texts
+        "text": texts,
     }
 
 
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
+def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
     """Load model and compute perplexity in subprocess"""
     from unsloth import FastLanguageModel
     from tests.utils.perplexity_eval import ppl_model
 
     # Load model
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_qwen_text_model",
-        max_seq_length=2048,
-        load_in_4bit=load_in_4bit,
-        load_in_8bit=load_in_8bit,
+        model_name = "./unsloth_out/merged_qwen_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = load_in_4bit,
+        load_in_8bit = load_in_8bit,
     )
     # Set up tokenizer
     # merged_tokenizer = get_chat_template(
@@ -88,7 +97,9 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
     # )
 
     # Load dataset fresh in subprocess
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
     alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
@@ -132,21 +143,18 @@ def formatting_prompts_func(examples):
             "instruction": instructions,
             "input": inputs,
             "output": outputs,
-            "text": texts
+            "text": texts,
         }
 
-
-
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     # Compute perplexity using the passed dataset
     ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
 
-
     # IMPORTANT: Convert to Python float if it's a tensor
     if torch.is_tensor(ppl_value):
         ppl_value = ppl_value.cpu().item()  # Move to CPU and convert to Python scalar
-    elif hasattr(ppl_value, 'item'):
+    elif hasattr(ppl_value, "item"):
         ppl_value = ppl_value.item()  # Convert numpy or other array types
     else:
         ppl_value = float(ppl_value)  # Ensure it's a float
@@ -161,74 +169,86 @@ def formatting_prompts_func(examples):
     # torch.cuda.empty_cache()
     # gc.collect()
 
+
 # Main execution code should be wrapped in this guard
 if __name__ == "__main__":
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method("spawn", force = True)
 
     if torch.cuda.is_bf16_supported():
         compute_dtype = torch.bfloat16
-        attn_implementation = 'flash_attention_2'
+        attn_implementation = "flash_attention_2"
     else:
         compute_dtype = torch.float16
-        attn_implementation = 'sdpa'
+        attn_implementation = "sdpa"
 
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name="unsloth/Qwen2.5-7B-Instruct",
-        max_seq_length=2048,
-        dtype=compute_dtype,
-        load_in_4bit=True,
-        load_in_8bit=False,
-        full_finetuning=False,
-        attn_implementation=attn_implementation
+        model_name = "unsloth/Qwen2.5-7B-Instruct",
+        max_seq_length = 2048,
+        dtype = compute_dtype,
+        load_in_4bit = True,
+        load_in_8bit = False,
+        full_finetuning = False,
+        attn_implementation = attn_implementation,
     )
 
+    dataset_train = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "train"
+    )
+    dataset_ppl = load_dataset(
+        "allenai/openassistant-guanaco-reformatted", split = "eval"
+    )
 
-    dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-    dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
-
-    dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+    dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+    dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
     add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
 
     model = FastLanguageModel.get_peft_model(
         model,
-        r=16,
-        target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-        lora_alpha=16,
-        lora_dropout=0,
-        bias="none",
-        use_gradient_checkpointing="unsloth",
-        random_state=3407,
-        use_rslora=False,
-        loftq_config=None,
+        r = 16,
+        target_modules = [
+            "k_proj",
+            "q_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "down_proj",
+            "up_proj",
+        ],
+        lora_alpha = 16,
+        lora_dropout = 0,
+        bias = "none",
+        use_gradient_checkpointing = "unsloth",
+        random_state = 3407,
+        use_rslora = False,
+        loftq_config = None,
     )
 
     from unsloth import is_bfloat16_supported
 
     trainer = SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        train_dataset=dataset_train,
-        dataset_text_field="text",
-        max_seq_length=2048,
-        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
-        dataset_num_proc=2,
-        packing=False,
-        args=TrainingArguments(
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            warmup_ratio=0.1,
-            max_steps=200,
-            learning_rate=2e-4,
-            fp16=not is_bfloat16_supported(),
-            bf16=is_bfloat16_supported(),
-            logging_steps=50,
-            optim="adamw_8bit",
-            lr_scheduler_type="linear",
-            seed=3407,
-            output_dir="outputs",
-            report_to="none",
+        model = model,
+        tokenizer = tokenizer,
+        train_dataset = dataset_train,
+        dataset_text_field = "text",
+        max_seq_length = 2048,
+        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+        dataset_num_proc = 2,
+        packing = False,
+        args = TrainingArguments(
+            per_device_train_batch_size = 2,
+            gradient_accumulation_steps = 4,
+            warmup_ratio = 0.1,
+            max_steps = 200,
+            learning_rate = 2e-4,
+            fp16 = not is_bfloat16_supported(),
+            bf16 = is_bfloat16_supported(),
+            logging_steps = 50,
+            optim = "adamw_8bit",
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = "outputs",
+            report_to = "none",
         ),
     )
 
@@ -240,8 +260,7 @@ def formatting_prompts_func(examples):
     # saving and merging the model to local disk
     print("merge and save to local disk")
     model.save_pretrained_merged(
-        save_directory='./unsloth_out/merged_qwen_text_model',
-        tokenizer=tokenizer
+        save_directory = "./unsloth_out/merged_qwen_text_model", tokenizer = tokenizer
     )
 
     # print("cleaning")
@@ -253,18 +272,19 @@ def formatting_prompts_func(examples):
     # load model from local disk and test
     print("Loading merged model in 4 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_qwen_text_model",
-        max_seq_length=2048,
-        load_in_4bit=True,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_qwen_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = True,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
-
+    add_to_comparison(
+        "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
+    )
 
     print("Computing 8-bit model perplexity in subprocess...")
     result_queue = mp.Queue()
-    p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
+    p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
     p.start()
     p.join()
 
@@ -273,13 +293,16 @@ def formatting_prompts_func(examples):
 
     print("Loading merged model in 16 bit for perplexity test")
     merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
-        model_name="./unsloth_out/merged_qwen_text_model",
-        max_seq_length=2048,
-        load_in_4bit=False,
-        load_in_8bit=False,
+        model_name = "./unsloth_out/merged_qwen_text_model",
+        max_seq_length = 2048,
+        load_in_4bit = False,
+        load_in_8bit = False,
     )
 
-    add_to_comparison("merged model loaded 16bits", ppl_model(merged_model, merged_tokenizer, dataset_ppl))
+    add_to_comparison(
+        "merged model loaded 16bits",
+        ppl_model(merged_model, merged_tokenizer, dataset_ppl),
+    )
 
     print_model_comparison()
 
diff --git a/tests/saving/language_models/test_push_to_hub_merged.py b/tests/saving/language_models/test_push_to_hub_merged.py
index e1c23fa47..58d589305 100644
--- a/tests/saving/language_models/test_push_to_hub_merged.py
+++ b/tests/saving/language_models/test_push_to_hub_merged.py
@@ -1,7 +1,11 @@
 from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
 from unsloth.chat_templates import get_chat_template
 from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, TrainingArguments
+from transformers import (
+    DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
+    TrainingArguments,
+)
 from datasets import load_dataset, Dataset
 import torch
 from tqdm import tqdm
@@ -11,6 +15,7 @@
 import gc
 import os
 from huggingface_hub import HfFileSystem, hf_hub_download
+
 # ruff: noqa
 import sys
 from pathlib import Path
@@ -20,94 +25,112 @@
 sys.path.insert(0, str(REPO_ROOT))
 
 from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import ppl_model, add_to_comparison, print_model_comparison
+from tests.utils.perplexity_eval import (
+    ppl_model,
+    add_to_comparison,
+    print_model_comparison,
+)
 
 
 # Define helper functions outside of main
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize = False, add_generation_prompt = False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
 
-
 if torch.cuda.is_bf16_supported():
     compute_dtype = torch.bfloat16
-    attn_implementation = 'flash_attention_2'
+    attn_implementation = "flash_attention_2"
 else:
     compute_dtype = torch.float16
-    attn_implementation = 'sdpa'
+    attn_implementation = "sdpa"
 
 model, tokenizer = FastLanguageModel.from_pretrained(
-    model_name="unsloth/Llama-3.2-1B-Instruct",
-    max_seq_length=2048,
-    dtype=compute_dtype,
-    load_in_4bit=True,
-    load_in_8bit=False,
-    full_finetuning=False,
-    attn_implementation=attn_implementation
+    model_name = "unsloth/Llama-3.2-1B-Instruct",
+    max_seq_length = 2048,
+    dtype = compute_dtype,
+    load_in_4bit = True,
+    load_in_8bit = False,
+    full_finetuning = False,
+    attn_implementation = attn_implementation,
 )
 
 tokenizer = get_chat_template(
     tokenizer,
-    chat_template="llama-3.1",
+    chat_template = "llama-3.1",
 )
 
 from unsloth.chat_templates import standardize_sharegpt
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
 
-dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
+dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
+
+dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
 add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
 
 model = FastLanguageModel.get_peft_model(
     model,
-    r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-    lora_alpha=16,
-    lora_dropout=0,
-    bias="none",
-    use_gradient_checkpointing="unsloth",
-    random_state=3407,
-    use_rslora=False,
-    loftq_config=None,
+    r = 16,
+    target_modules = [
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
+    lora_alpha = 16,
+    lora_dropout = 0,
+    bias = "none",
+    use_gradient_checkpointing = "unsloth",
+    random_state = 3407,
+    use_rslora = False,
+    loftq_config = None,
 )
 
 from unsloth import is_bfloat16_supported
 
 trainer = SFTTrainer(
-    model=model,
-    tokenizer=tokenizer,
-    train_dataset=dataset_train,
-    dataset_text_field="text",
-    max_seq_length=2048,
-    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
-    dataset_num_proc=2,
-    packing=False,
-    args=TrainingArguments(
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        warmup_ratio=0.1,
-        max_steps=30,
-        learning_rate=2e-4,
-        fp16=not is_bfloat16_supported(),
-        bf16=is_bfloat16_supported(),
-        logging_steps=50,
-        optim="adamw_8bit",
-        lr_scheduler_type="linear",
-        seed=3407,
-        output_dir="outputs",
-        report_to="none",
+    model = model,
+    tokenizer = tokenizer,
+    train_dataset = dataset_train,
+    dataset_text_field = "text",
+    max_seq_length = 2048,
+    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+    dataset_num_proc = 2,
+    packing = False,
+    args = TrainingArguments(
+        per_device_train_batch_size = 2,
+        gradient_accumulation_steps = 4,
+        warmup_ratio = 0.1,
+        max_steps = 30,
+        learning_rate = 2e-4,
+        fp16 = not is_bfloat16_supported(),
+        bf16 = is_bfloat16_supported(),
+        logging_steps = 50,
+        optim = "adamw_8bit",
+        lr_scheduler_type = "linear",
+        seed = 3407,
+        output_dir = "outputs",
+        report_to = "none",
     ),
 )
 
 from unsloth.chat_templates import train_on_responses_only
+
 trainer = train_on_responses_only(
     trainer,
-    instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
-    response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
 )
 
 # run training
@@ -128,16 +151,16 @@ def formatting_prompts_func(examples):
 
 repo_name = f"{hf_username}/merged_llama_text_model"
 success = {
-        "upload": False,
-        "download": False,
- }
+    "upload": False,
+    "download": False,
+}
 
 # Stage 1: Upload model to Hub
 try:
     print("\n" + "=" * 80)
     print("=== UPLOADING MODEL TO HUB ===".center(80))
     print("=" * 80 + "\n")
-    model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
     success["upload"] = True
     print("ā
 Model uploaded successfully!")
 except Exception as e:
@@ -153,7 +176,9 @@ def formatting_prompts_func(examples):
     print("=== TESTING MODEL DOWNLOAD ===".center(80))
     print("=" * 80 + "\n")
     # Force download even if cached
-    model,tokenizer = FastLanguageModel.from_pretrained(f"{hf_username}/merged_llama_text_model")
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        f"{hf_username}/merged_llama_text_model"
+    )
     success["download"] = True
     print("ā
 Model downloaded successfully!")
 except Exception as e:
diff --git a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py b/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
index 04bbf2924..038565d17 100644
--- a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
+++ b/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
@@ -1,7 +1,11 @@
 from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
 from unsloth.chat_templates import get_chat_template
 from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, TrainingArguments
+from transformers import (
+    DataCollatorForLanguageModeling,
+    DataCollatorForSeq2Seq,
+    TrainingArguments,
+)
 from datasets import load_dataset, Dataset
 import torch
 from tqdm import tqdm
@@ -11,6 +15,7 @@
 import gc
 import os
 from huggingface_hub import HfFileSystem, hf_hub_download
+
 # ruff: noqa
 import sys
 from pathlib import Path
@@ -20,94 +25,112 @@
 sys.path.insert(0, str(REPO_ROOT))
 
 from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import ppl_model, add_to_comparison, print_model_comparison
+from tests.utils.perplexity_eval import (
+    ppl_model,
+    add_to_comparison,
+    print_model_comparison,
+)
 
 
 # Define helper functions outside of main
 def formatting_prompts_func(examples):
     convos = examples["messages"]
-    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+    texts = [
+        tokenizer.apply_chat_template(
+            convo, tokenize = False, add_generation_prompt = False
+        )
+        for convo in convos
+    ]
     return {"text": texts}
 
 
-
 if torch.cuda.is_bf16_supported():
     compute_dtype = torch.bfloat16
-    attn_implementation = 'flash_attention_2'
+    attn_implementation = "flash_attention_2"
 else:
     compute_dtype = torch.float16
-    attn_implementation = 'sdpa'
+    attn_implementation = "sdpa"
 
 model, tokenizer = FastLanguageModel.from_pretrained(
-    model_name="unsloth/Llama-3.1-8B-Instruct",
-    max_seq_length=2048,
-    dtype=compute_dtype,
-    load_in_4bit=True,
-    load_in_8bit=False,
-    full_finetuning=False,
-    attn_implementation=attn_implementation
+    model_name = "unsloth/Llama-3.1-8B-Instruct",
+    max_seq_length = 2048,
+    dtype = compute_dtype,
+    load_in_4bit = True,
+    load_in_8bit = False,
+    full_finetuning = False,
+    attn_implementation = attn_implementation,
 )
 
 tokenizer = get_chat_template(
     tokenizer,
-    chat_template="llama-3.1",
+    chat_template = "llama-3.1",
 )
 
 from unsloth.chat_templates import standardize_sharegpt
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
-dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
 
-dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
-dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
+dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
+dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
+
+dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
+dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
 
 add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
 
 model = FastLanguageModel.get_peft_model(
     model,
-    r=16,
-    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
-    lora_alpha=16,
-    lora_dropout=0,
-    bias="none",
-    use_gradient_checkpointing="unsloth",
-    random_state=3407,
-    use_rslora=False,
-    loftq_config=None,
+    r = 16,
+    target_modules = [
+        "k_proj",
+        "q_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ],
+    lora_alpha = 16,
+    lora_dropout = 0,
+    bias = "none",
+    use_gradient_checkpointing = "unsloth",
+    random_state = 3407,
+    use_rslora = False,
+    loftq_config = None,
 )
 
 from unsloth import is_bfloat16_supported
 
 trainer = SFTTrainer(
-    model=model,
-    tokenizer=tokenizer,
-    train_dataset=dataset_train,
-    dataset_text_field="text",
-    max_seq_length=2048,
-    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
-    dataset_num_proc=2,
-    packing=False,
-    args=TrainingArguments(
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        warmup_ratio=0.1,
-        max_steps=30,
-        learning_rate=2e-4,
-        fp16=not is_bfloat16_supported(),
-        bf16=is_bfloat16_supported(),
-        logging_steps=50,
-        optim="adamw_8bit",
-        lr_scheduler_type="linear",
-        seed=3407,
-        output_dir="outputs",
-        report_to="none",
+    model = model,
+    tokenizer = tokenizer,
+    train_dataset = dataset_train,
+    dataset_text_field = "text",
+    max_seq_length = 2048,
+    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+    dataset_num_proc = 2,
+    packing = False,
+    args = TrainingArguments(
+        per_device_train_batch_size = 2,
+        gradient_accumulation_steps = 4,
+        warmup_ratio = 0.1,
+        max_steps = 30,
+        learning_rate = 2e-4,
+        fp16 = not is_bfloat16_supported(),
+        bf16 = is_bfloat16_supported(),
+        logging_steps = 50,
+        optim = "adamw_8bit",
+        lr_scheduler_type = "linear",
+        seed = 3407,
+        output_dir = "outputs",
+        report_to = "none",
     ),
 )
 
 from unsloth.chat_templates import train_on_responses_only
+
 trainer = train_on_responses_only(
     trainer,
-    instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
-    response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
 )
 
 # run training
@@ -128,17 +151,17 @@ def formatting_prompts_func(examples):
 
 repo_name = f"{hf_username}/merged_llama_text_model"
 success = {
-        "upload": False,
-        "safetensors_check": False,
-        "download": False,
- }
+    "upload": False,
+    "safetensors_check": False,
+    "download": False,
+}
 
 # Stage 1: Upload model to Hub
 try:
     print("\n" + "=" * 80)
     print("=== UPLOADING MODEL TO HUB ===".center(80))
     print("=" * 80 + "\n")
-    model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
     success["upload"] = True
     print("ā
 Model uploaded successfully!")
 except Exception as e:
@@ -150,8 +173,8 @@ def formatting_prompts_func(examples):
     print("\n" + "=" * 80)
     print("=== VERIFYING REPO CONTENTS ===".center(80))
     print("=" * 80 + "\n")
-    fs = HfFileSystem(token=hf_token)
-    file_list = fs.ls(repo_name, detail=True)
+    fs = HfFileSystem(token = hf_token)
+    file_list = fs.ls(repo_name, detail = True)
     safetensors_found = any(
         file["name"].endswith("model.safetensors.index.json") for file in file_list
     )
@@ -172,7 +195,9 @@ def formatting_prompts_func(examples):
     print("=== TESTING MODEL DOWNLOAD ===".center(80))
     print("=" * 80 + "\n")
     # Force download even if cached
-    model,tokenizer = FastLanguageModel.from_pretrained(f"{hf_username}/merged_llama_text_model")
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        f"{hf_username}/merged_llama_text_model"
+    )
     success["download"] = True
     print("ā
 Model downloaded successfully!")
 except Exception as e:
diff --git a/tests/saving/language_models/test_save_merged_grpo_model.py b/tests/saving/language_models/test_save_merged_grpo_model.py
index 0bbb7ffd4..1f0b377a3 100644
--- a/tests/saving/language_models/test_save_merged_grpo_model.py
+++ b/tests/saving/language_models/test_save_merged_grpo_model.py
@@ -20,46 +20,47 @@
 from tests.utils.aime_eval import evaluate_model_aime, compare_aime_results
 
 
-max_seq_length = 2048 # Can increase for longer reasoning traces
-lora_rank = 64 # Larger rank = smarter, but slower
+max_seq_length = 2048  # Can increase for longer reasoning traces
+lora_rank = 64  # Larger rank = smarter, but slower
 
 
-def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
+def evaluate_merged_model(result_queue, load_in_4bit = False, load_in_8bit = False):
     from unsloth import FastLanguageModel
     from tests.utils.aime_eval import evaluate_model_aime
-    max_seq_length = 2048 # Can increase for longer reasoning traces
-    lora_rank = 64 # Larger rank = smarter, but slower
+
+    max_seq_length = 2048  # Can increase for longer reasoning traces
+    lora_rank = 64  # Larger rank = smarter, but slower
 
     model, tokenizer = FastLanguageModel.from_pretrained(
         model_name = "./final_merged_model",
         max_seq_length = max_seq_length,
-        load_in_4bit = True, # False for LoRA 16bit
-        fast_inference = True, # Enable vLLM fast inference
+        load_in_4bit = True,  # False for LoRA 16bit
+        fast_inference = True,  # Enable vLLM fast inference
         max_lora_rank = lora_rank,
-        gpu_memory_utilization = 0.8, # Reduce if out of memory
+        gpu_memory_utilization = 0.8,  # Reduce if out of memory
     )
 
-    print(f"\n{'='*60}")
+    print(f"\n{'=' * 60}")
     if load_in_4bit:
         print("š EVALUATION Merged model: 4 bits load")
-        model_type="merged_model_4bits"
+        model_type = "merged_model_4bits"
     elif load_in_8bit:
         print("š EVALUATION Merged model: 8 bits load")
-        model_type="merged_model_8bits"
+        model_type = "merged_model_8bits"
     else:
         print("š EVALUATION Merged model: 16 bits load")
-        model_type="merged_model_16bits"
-    print(f"{'='*60}")
+        model_type = "merged_model_16bits"
+    print(f"{'=' * 60}")
 
     evaluate_model_aime(
-        model=model,
-        tokenizer=tokenizer,
-        model_type=model_type,
-        temperature=0.3,
-        n_sampling=8,
-        max_tokens=32768,
-        top_p=0.95,
-        seed=0
+        model = model,
+        tokenizer = tokenizer,
+        model_type = model_type,
+        temperature = 0.3,
+        n_sampling = 8,
+        max_tokens = 32768,
+        top_p = 0.95,
+        seed = 0,
     )
 
     result_queue.put(results)
@@ -70,16 +71,15 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
     gc.collect()
 
 
-
 # Main execution code should be wrapped in this guard
 def training_run(result_queue):
     model, tokenizer = FastLanguageModel.from_pretrained(
         model_name = "meta-llama/Llama-3.2-3B-Instruct",
         max_seq_length = max_seq_length,
-        load_in_4bit = False, # False for LoRA 16bit
-        fast_inference = True, # Enable vLLM fast inference
+        load_in_4bit = False,  # False for LoRA 16bit
+        fast_inference = True,  # Enable vLLM fast inference
         max_lora_rank = lora_rank,
-        gpu_memory_utilization = 0.8, # Reduce if out of memory
+        gpu_memory_utilization = 0.8,  # Reduce if out of memory
     )
 
     """### Helper Functions
@@ -149,7 +149,7 @@ def format_limo(example):
                 "prompt": [  # ā This is the key change - wrap in a dict
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": example["question"]},
-                    {"role": "assistant", "content": assistant_response}
+                    {"role": "assistant", "content": assistant_response},
                 ]
             }
 
@@ -166,22 +166,22 @@ def get_max_prompt_length(dataset, tokenizer):
         lengths = dataset.map(
             lambda x: {
                 "tokens": tokenizer.apply_chat_template(
-                    x["prompt"],
-                    add_generation_prompt=True,
-                    tokenize=True
+                    x["prompt"], add_generation_prompt = True, tokenize = True
                 )
             },
-            batched=True,
+            batched = True,
         ).map(lambda x: {"length": len(x["tokens"])})["length"]
 
         max_length = max(lengths)
         avg_length = sum(lengths) / len(lengths)
         min_length = min(lengths)
 
-        print(f"Prompt lengths - Min: {min_length}, Max: {max_length}, Avg: {avg_length:.1f}")
+        print(
+            f"Prompt lengths - Min: {min_length}, Max: {max_length}, Avg: {avg_length:.1f}"
+        )
         return max_length, avg_length
 
-    def extract_unsloth_answer(text, start_tag="", end_tag=""):
+    def extract_unsloth_answer(text, start_tag = "", end_tag = ""):
         """Extract answer from Unsloth SOLUTION tags"""
         pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
         matches = re.findall(pattern, text, re.DOTALL)
@@ -213,10 +213,10 @@ def get_num_tokens(text, tokenizer_instance):
         """Count tokens in text"""
         if not text:
             return 0
-        encoding = tokenizer_instance(text, return_tensors="pt")
+        encoding = tokenizer_instance(text, return_tensors = "pt")
         return len(encoding["input_ids"][0])
 
-    def check_format_compliance(text, format_type="unsloth"):
+    def check_format_compliance(text, format_type = "unsloth"):
         """Check if response follows expected format"""
         if format_type == "unsloth":
             reasoning_start = ""
@@ -265,7 +265,9 @@ def evaluate_answer_correctness(extracted_answer, ground_truth):
             ground_truth_num = float(norm_ground_truth)
 
             if ground_truth_num != 0:
-                relative_error = abs(extracted_num - ground_truth_num) / abs(ground_truth_num)
+                relative_error = abs(extracted_num - ground_truth_num) / abs(
+                    ground_truth_num
+                )
 
                 if relative_error < 0.01:
                     return True, True, 0.9
@@ -300,7 +302,10 @@ def match_format_exactly(completions, **kwargs):
         )
 
         responses = [completion[0]["content"] for completion in completions]
-        rewards = [3.0 if re.match(pattern, response, re.DOTALL) else 0.0 for response in responses]
+        rewards = [
+            3.0 if re.match(pattern, response, re.DOTALL) else 0.0
+            for response in responses
+        ]
         return rewards
 
     def match_format_approximately(completions, **kwargs):
@@ -323,6 +328,7 @@ def match_format_approximately(completions, **kwargs):
 
     def check_answer_correctness(prompts, completions, answer, **kwargs):
         """Reward function for answer correctness"""
+
         def extract_solution_answer(text):
             pattern = r"(.*?)"
             match = re.search(pattern, text, re.DOTALL)
@@ -364,39 +370,47 @@ def extract_solution_answer(text):
 
     import gc
 
-
-
     """#### Comparison and Memory Management"""
 
     def compare_model_results(all_results):
         """Generate comprehensive comparison of multiple model results"""
-        print(f"\n{'='*80}")
+        print(f"\n{'=' * 80}")
         print("COMPREHENSIVE MODEL COMPARISON")
-        print(f"{'='*80}")
+        print(f"{'=' * 80}")
 
         # Main table
-        print(f"{'Model':<15} {'Format %':<10} {'Exact %':<10} {'Plausible %':<12} {'Confidence':<12}")
+        print(
+            f"{'Model':<15} {'Format %':<10} {'Exact %':<10} {'Plausible %':<12} {'Confidence':<12}"
+        )
         print("-" * 80)
 
         for result in all_results:
-            print(f"{result['model_type']:<15} "
-                  f"{result['correct_format_pct']:<10.1f} "
-                  f"{result['exact_match_pct']:<10.1f} "
-                  f"{result['plausible_match_pct']:<12.1f} "
-                  f"{result['avg_confidence']:<12.3f}")
+            print(
+                f"{result['model_type']:<15} "
+                f"{result['correct_format_pct']:<10.1f} "
+                f"{result['exact_match_pct']:<10.1f} "
+                f"{result['plausible_match_pct']:<12.1f} "
+                f"{result['avg_confidence']:<12.3f}"
+            )
 
         # Improvement analysis
         if len(all_results) > 1:
-            print(f"\n{'='*50}")
+            print(f"\n{'=' * 50}")
             print("IMPROVEMENT ANALYSIS")
-            print(f"{'='*50}")
+            print(f"{'=' * 50}")
 
             base_result = all_results[0]
             for result in all_results[1:]:
                 print(f"\n{result['model_type']} vs {base_result['model_type']}:")
-                format_improvement = result['correct_format_pct'] - base_result['correct_format_pct']
-                exact_improvement = result['exact_match_pct'] - base_result['exact_match_pct']
-                plausible_improvement = result['plausible_match_pct'] - base_result['plausible_match_pct']
+                format_improvement = (
+                    result["correct_format_pct"] - base_result["correct_format_pct"]
+                )
+                exact_improvement = (
+                    result["exact_match_pct"] - base_result["exact_match_pct"]
+                )
+                plausible_improvement = (
+                    result["plausible_match_pct"] - base_result["plausible_match_pct"]
+                )
 
                 print(f"  Format compliance: {format_improvement:+.1f}%")
                 print(f"  Exact matches:     {exact_improvement:+.1f}%")
@@ -405,14 +419,16 @@ def compare_model_results(all_results):
         # Save comparison
         comparison_data = {
             "summary": all_results,
-            "best_model": max(all_results, key=lambda x: x['exact_match_pct']),
+            "best_model": max(all_results, key = lambda x: x["exact_match_pct"]),
         }
 
         with open("model_comparison_comprehensive.json", "w") as f:
-            json.dump(comparison_data, f, indent=4)
+            json.dump(comparison_data, f, indent = 4)
 
-        print(f"\nBest performing model: {comparison_data['best_model']['model_type']} "
-              f"({comparison_data['best_model']['exact_match_pct']:.1f}% exact matches)")
+        print(
+            f"\nBest performing model: {comparison_data['best_model']['model_type']} "
+            f"({comparison_data['best_model']['exact_match_pct']:.1f}% exact matches)"
+        )
 
     def cleanup_memory():
         """Comprehensive memory cleanup"""
@@ -424,40 +440,40 @@ def cleanup_memory():
         if torch.cuda.is_available():
             allocated = torch.cuda.memory_allocated() / 1024**3
             reserved = torch.cuda.memory_reserved() / 1024**3
-            print(f"GPU memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
+            print(
+                f"GPU memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB"
+            )
 
     """#### Data Loading and Preparation"""
 
     from datasets import load_dataset
 
+    # Load GSM8K
+    gsm8k_dataset = load_dataset("openai/gsm8k", "main", split = "train")
 
-# Load GSM8K
-    gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="train")
-
-# Load LIMO (adjust this based on your access method)
-    limo_train = load_dataset("GAIR/LIMO", split="train")
+    # Load LIMO (adjust this based on your access method)
+    limo_train = load_dataset("GAIR/LIMO", split = "train")
 
-# Prepare datasets
+    # Prepare datasets
     gsm8k_train = prepare_gsm8k_dataset(gsm8k_dataset)
     limo_train = prepare_limo_dataset(limo_train)
 
-
     print(f"  GSM8K train: {len(gsm8k_train)}")
     print(f"  LIMO train:  {len(limo_train) if limo_train else 0}")
 
-# Store results
+    # Store results
     all_results = []
 
-# Single temperature evaluation on combined dataset
+    # Single temperature evaluation on combined dataset
     results = evaluate_model_aime(
-        model=model,
-        tokenizer=tokenizer,
-        model_type="base",
-        temperature=0.3,
-        n_sampling=8,
-        max_tokens=32768,
-        top_p=0.95,
-        seed=0
+        model = model,
+        tokenizer = tokenizer,
+        model_type = "base",
+        temperature = 0.3,
+        n_sampling = 8,
+        max_tokens = 32768,
+        top_p = 0.95,
+        seed = 0,
     )
 
     from unsloth.chat_templates import get_chat_template
@@ -469,69 +485,81 @@ def cleanup_memory():
 
     def formatting_prompts_func(examples):
         convos = examples["prompt"]
-        texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
-        return { "text" : texts, }
-    pass
+        texts = [
+            tokenizer.apply_chat_template(
+                convo, tokenize = False, add_generation_prompt = False
+            )
+            for convo in convos
+        ]
+        return {
+            "text": texts,
+        }
 
-    limo_train = limo_train.map(formatting_prompts_func, batched = True,)
+    limo_train = limo_train.map(
+        formatting_prompts_func,
+        batched = True,
+    )
 
     from trl import SFTTrainer
     from transformers import DataCollatorForSeq2Seq, TrainingArguments
     from unsloth import is_bfloat16_supported
 
-
-    print(f"\n{'*'*60}")
+    print(f"\n{'*' * 60}")
     print("šÆ STAGE 1: Qlora Fine-Tuning on LIMO")
-    print(f"{'*'*60}")
+    print(f"{'*' * 60}")
 
     model = FastLanguageModel.get_peft_model(
         model,
-        r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+        r = lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
         target_modules = [
-            "q_proj", "k_proj", "v_proj", "o_proj",
-            "gate_proj", "up_proj", "down_proj",
-        ], # Remove QKVO if out of memory
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],  # Remove QKVO if out of memory
         lora_alpha = lora_rank,
-        use_gradient_checkpointing = "unsloth", # Enable long context finetuning
+        use_gradient_checkpointing = "unsloth",  # Enable long context finetuning
         random_state = 3407,
     )
 
-
     if limo_train is not None:
         trainer = SFTTrainer(
-        model = model,
-        tokenizer = tokenizer,
-        train_dataset = limo_train,
-        dataset_text_field = "text",
-        max_seq_length = max_seq_length,
-        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
-        dataset_num_proc = 2,
-        packing = False, # Can make training 5x faster for short sequences.
-        args = TrainingArguments(
-            per_device_train_batch_size = 2,
-            gradient_accumulation_steps = 4,
-            warmup_steps = 5,
-            num_train_epochs = 1, # Set this for 1 full training run.
-            #max_steps = 60,
-            learning_rate = 2e-4,
-            fp16 = not is_bfloat16_supported(),
-            bf16 = is_bfloat16_supported(),
-            logging_steps = 1,
-            optim = "adamw_8bit",
-            weight_decay = 0.01,
-            lr_scheduler_type = "linear",
-            seed = 3407,
-            output_dir = "outputs",
-            report_to = "none", # Use this for WandB etc
-        ),
-    )
-
+            model = model,
+            tokenizer = tokenizer,
+            train_dataset = limo_train,
+            dataset_text_field = "text",
+            max_seq_length = max_seq_length,
+            data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
+            dataset_num_proc = 2,
+            packing = False,  # Can make training 5x faster for short sequences.
+            args = TrainingArguments(
+                per_device_train_batch_size = 2,
+                gradient_accumulation_steps = 4,
+                warmup_steps = 5,
+                num_train_epochs = 1,  # Set this for 1 full training run.
+                # max_steps = 60,
+                learning_rate = 2e-4,
+                fp16 = not is_bfloat16_supported(),
+                bf16 = is_bfloat16_supported(),
+                logging_steps = 1,
+                optim = "adamw_8bit",
+                weight_decay = 0.01,
+                lr_scheduler_type = "linear",
+                seed = 3407,
+                output_dir = "outputs",
+                report_to = "none",  # Use this for WandB etc
+            ),
+        )
 
         from unsloth.chat_templates import train_on_responses_only
+
         trainer = train_on_responses_only(
-        trainer,
-        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
-        response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
+            trainer,
+            instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
+            response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
         )
 
         # Train
@@ -551,7 +579,7 @@ def formatting_prompts_func(examples):
     else:
         print("ā ļø Skipping Qlora training - no LIMO dataset available")
 
-# Cleanup
+    # Cleanup
     cleanup_memory()
 
     global PRINTED_TIMES
@@ -560,8 +588,7 @@ def formatting_prompts_func(examples):
     PRINT_EVERY_STEPS = 5
 
     match_numbers = re.compile(
-        solution_start + r".*?([\d\.\,]{1,})",
-        flags = re.MULTILINE | re.DOTALL
+        solution_start + r".*?([\d\.\,]{1,})", flags = re.MULTILINE | re.DOTALL
     )
 
     def check_numbers(prompts, completions, answer, **kwargs):
@@ -569,8 +596,7 @@ def check_numbers(prompts, completions, answer, **kwargs):
         responses = [completion[0]["content"] for completion in completions]
 
         extracted_responses = [
-            guess.group(1)
-            if (guess := match_numbers.search(r)) is not None else None \
+            guess.group(1) if (guess := match_numbers.search(r)) is not None else None
             for r in responses
         ]
 
@@ -579,7 +605,13 @@ def check_numbers(prompts, completions, answer, **kwargs):
         global PRINTED_TIMES
         global PRINT_EVERY_STEPS
         if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
-            print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+            print(
+                "*" * 20,
+                f"Question:\n{question}",
+                f"\nAnswer:\n{answer[0]}",
+                f"\nResponse:\n{responses[0]}",
+                f"\nExtracted:\n{extracted_responses[0]}",
+            )
         PRINTED_TIMES += 1
 
         for guess, true_answer in zip(extracted_responses, answer):
@@ -590,24 +622,25 @@ def check_numbers(prompts, completions, answer, **kwargs):
             try:
                 true_answer = float(true_answer.strip())
                 # Remove commas like in 123,456
-                guess       = float(guess.strip().replace(",", ""))
+                guess = float(guess.strip().replace(",", ""))
                 scores.append(1.5 if guess == true_answer else -0.5)
             except:
                 scores.append(0)
                 continue
         return scores
 
-    print(f"\n{'*'*60}")
+    print(f"\n{'*' * 60}")
     print("šÆ STAGE 2: GRPO Fine-Tuning on GSM8K")
-    print(f"{'*'*60}")
+    print(f"{'*' * 60}")
 
-# Get max prompt length
+    # Get max prompt length
     max_prompt_length, _ = get_max_prompt_length(gsm8k_train, tokenizer)
     max_prompt_length = min(max_prompt_length + 10, 512)  # Add buffer, cap at 512
 
     print(f"Using max_prompt_length: {max_prompt_length}")
 
     from trl import GRPOConfig, GRPOTrainer
+
     training_args = GRPOConfig(
         learning_rate = 5e-6,
         weight_decay = 0.1,
@@ -616,16 +649,16 @@ def check_numbers(prompts, completions, answer, **kwargs):
         optim = "adamw_torch_fused",
         logging_steps = 1,
         per_device_train_batch_size = 1,
-        gradient_accumulation_steps = 4, # Increase to 4 for smoother training
-        num_generations = 8, # Decrease if out of memory
+        gradient_accumulation_steps = 4,  # Increase to 4 for smoother training
+        num_generations = 8,  # Decrease if out of memory
         max_prompt_length = max_prompt_length,
         max_completion_length = max_seq_length - max_prompt_length,
         # num_train_epochs = 1, # Set to 1 for a full training run
-        #max_steps = 250,
+        # max_steps = 250,
         max_steps = 1000,
         save_steps = 250,
         max_grad_norm = 0.1,
-        report_to = "none", # Can use Weights & Biases
+        report_to = "none",  # Can use Weights & Biases
         output_dir = "outputs",
     )
 
@@ -642,48 +675,49 @@ def check_numbers(prompts, completions, answer, **kwargs):
         train_dataset = gsm8k_train,
     )
 
-
-# Train
+    # Train
     print(f"š Starting GRPO training on {len(gsm8k_train)} examples...")
     trainer.train()
 
-# Save checkpoint
+    # Save checkpoint
     model.save_pretrained("grpo_checkpoint")
     tokenizer.save_pretrained("grpo_checkpoint")
     print("š¾ GRPO checkpoint saved!")
 
-# Cleanup
+    # Cleanup
     del trainer
     del training_args
     cleanup_memory()
 
     print("ā
 GRPO training completed!")
 
-    print(f"\n{'='*60}")
+    print(f"\n{'=' * 60}")
     print("š EVALUATION 3: Final GRPO Model")
-    print(f"{'='*60}")
+    print(f"{'=' * 60}")
 
     grpo_results = evaluate_model_aime(
-        model=model,
-        tokenizer=tokenizer,
-        model_type="grpo",
-        temperature=0.3,
-        n_sampling=8,
-        max_tokens=32768,
-        top_p=0.95,
-        seed=0
+        model = model,
+        tokenizer = tokenizer,
+        model_type = "grpo",
+        temperature = 0.3,
+        n_sampling = 8,
+        max_tokens = 32768,
+        top_p = 0.95,
+        seed = 0,
     )
 
     all_results.append(grpo_results)
     print("ā
 Final model evaluation complete!")
 
-    print(f"\n{'='*60}")
+    print(f"\n{'=' * 60}")
     print("š¾ SAVING FINAL MODEL")
-    print(f"{'='*60}")
+    print(f"{'=' * 60}")
 
     # Save as merged model
     try:
-        model.save_pretrained_merged("final_merged_model", tokenizer, save_method="merged_16bit")
+        model.save_pretrained_merged(
+            "final_merged_model", tokenizer, save_method = "merged_16bit"
+        )
         print("ā
 Merged model saved to: final_merged_model/")
     except Exception as e:
         print(f"ā ļø Could not save merged model: {e}")
@@ -691,7 +725,6 @@ def check_numbers(prompts, completions, answer, **kwargs):
 
     print("š¾ Model saving complete!")
 
-
     safe_remove_directory("./unsloth_compiled_cache")
 
     result_queue.put(results)
@@ -702,7 +735,6 @@ def check_numbers(prompts, completions, answer, **kwargs):
     torch.cuda.empty_cache()
     gc.collect()
 
-
     # # Merged model load 16 bits model AIME eval
     # result_queue = mp.Queue()
     # p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False))
@@ -731,7 +763,6 @@ def check_numbers(prompts, completions, answer, **kwargs):
     # merged_16bits = result_queue.get()
     # all_results.append(merged_16bits)
 
-
     # Merged model load 4 bits AIME eval
     # result_queue = mp.Queue()
     # p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False))
@@ -741,14 +772,14 @@ def check_numbers(prompts, completions, answer, **kwargs):
     # merged_16bits = result_queue.get()
     # all_results.append(merged_16bits)
 
+
 if __name__ == "__main__":
-    mp.set_start_method('spawn', force=True)
+    mp.set_start_method("spawn", force = True)
     result_queue = mp.Queue()
     all_results = []
 
-
     # run main finetuning and grpo loop
-    p = mp.Process(target=training_run, args=(result_queue,))
+    p = mp.Process(target = training_run, args = (result_queue,))
     p.start()
     p.join()
 
@@ -756,7 +787,7 @@ def check_numbers(prompts, completions, answer, **kwargs):
     all_results = results
 
     # evaluate merged model loaded 16bits
-    p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False))
+    p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, False))
     p.start()
     p.join()
 
@@ -765,7 +796,7 @@ def check_numbers(prompts, completions, answer, **kwargs):
     safe_remove_directory("./unsloth_compiled_cache")
 
     # Merged model load 8 bits model AIME eval
-    p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, True))
+    p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, True))
     p.start()
     p.join()
 
@@ -775,7 +806,7 @@ def check_numbers(prompts, completions, answer, **kwargs):
     safe_remove_directory("./unsloth_compiled_cache")
 
     # Merged model load 4 bits model AIME eval
-    p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False))
+    p = mp.Process(target = evaluate_merged_model, args = (result_queue, True, False))
     p.start()
     p.join()
 
@@ -784,12 +815,11 @@ def check_numbers(prompts, completions, answer, **kwargs):
 
     safe_remove_directory("./unsloth_compiled_cache")
 
+    # AIME-specific comparison function
 
-# AIME-specific comparison function
-
-    print(f"\n{'='*80}")
+    print(f"\n{'=' * 80}")
     print("š FINAL TRAINING PIPELINE RESULTS")
-    print(f"{'='*80}")
+    print(f"{'=' * 80}")
 
-# Use the AIME-specific comparison
+    # Use the AIME-specific comparison
     compare_aime_results(all_results)
diff --git a/tests/saving/non_peft/test_mistral_non_peft.py b/tests/saving/non_peft/test_mistral_non_peft.py
index 730815130..755e19f30 100644
--- a/tests/saving/non_peft/test_mistral_non_peft.py
+++ b/tests/saving/non_peft/test_mistral_non_peft.py
@@ -11,18 +11,18 @@
 from tests.utils.cleanup_utils import safe_remove_directory
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 1: Loading Base Model")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name="unsloth/mistral-7b-v0.3",
-        max_seq_length=2048,
-        dtype=None,
-        load_in_4bit=True,
-        load_in_8bit=False,
-        full_finetuning=False,
-    )
+    model_name = "unsloth/mistral-7b-v0.3",
+    max_seq_length = 2048,
+    dtype = None,
+    load_in_4bit = True,
+    load_in_8bit = False,
+    full_finetuning = False,
+)
 
 
 print("ā
 Base model loaded successfully!")
@@ -30,29 +30,27 @@
 ### Attemtping save merge
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 2: Attempting save_pretrained_merged (Should Warn)")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
-with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        model.save_pretrained_merged("test_output", tokenizer)
+with warnings.catch_warnings(record = True) as w:
+    warnings.simplefilter("always")
+    model.save_pretrained_merged("test_output", tokenizer)
 
-        # Verify warning
-        assert len(w) >= 1, "Expected warning but none raised"
-        warning_msg = str(w[0].message)
-        expected_msg = "Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!"
-        assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
-        assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
+    # Verify warning
+    assert len(w) >= 1, "Expected warning but none raised"
+    warning_msg = str(w[0].message)
+    expected_msg = "Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!"
+    assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
+    assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
 
 print("ā
 Correct warning detected for non-PeftModel merge attempt!")
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 3: Using save_pretrained (Should Succeed)")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 try:
diff --git a/tests/saving/non_peft/test_whisper_non_peft.py b/tests/saving/non_peft/test_whisper_non_peft.py
index 40321d29b..c72bd38e9 100644
--- a/tests/saving/non_peft/test_whisper_non_peft.py
+++ b/tests/saving/non_peft/test_whisper_non_peft.py
@@ -11,14 +11,14 @@
 from tests.utils.cleanup_utils import safe_remove_directory
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 1: Loading Base Model")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 model, tokenizer = FastModel.from_pretrained(
     model_name = "unsloth/whisper-large-v3",
-    dtype = None, # Leave as None for auto detection
-    load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
+    dtype = None,  # Leave as None for auto detection
+    load_in_4bit = False,  # Set to True to do 4bit quantization which reduces memory
     auto_model = WhisperForConditionalGeneration,
     whisper_language = "English",
     whisper_task = "transcribe",
@@ -30,29 +30,27 @@
 ### Attemtping save merge
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 2: Attempting save_pretrained_merged (Should Warn)")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
-with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        model.save_pretrained_merged("test_output", tokenizer)
+with warnings.catch_warnings(record = True) as w:
+    warnings.simplefilter("always")
+    model.save_pretrained_merged("test_output", tokenizer)
 
-        # Verify warning
-        assert len(w) >= 1, "Expected warning but none raised"
-        warning_msg = str(w[0].message)
-        expected_msg = "Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!"
-        assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
-        assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
+    # Verify warning
+    assert len(w) >= 1, "Expected warning but none raised"
+    warning_msg = str(w[0].message)
+    expected_msg = "Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!"
+    assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
+    assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
 
 print("ā
 Correct warning detected for non-PeftModel merge attempt!")
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š PHASE 3: Using save_pretrained (Should Succeed)")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 try:
diff --git a/tests/saving/test_unsloth_save.py b/tests/saving/test_unsloth_save.py
index 54078f36b..35fdad6ba 100644
--- a/tests/saving/test_unsloth_save.py
+++ b/tests/saving/test_unsloth_save.py
@@ -19,14 +19,14 @@
     # Vision Models
     "unsloth/gemma-3-4b-it",
     "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
-    "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit"
+    "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
 ]
 
 torchao_models = [
     "unsloth/tinyllama",
     "unsloth/Qwen2.5-0.5B-Instruct",
-    #"unsloth/Phi-4-mini-instruct",
-    #"unsloth/Qwen2.5-0.5B",
+    # "unsloth/Phi-4-mini-instruct",
+    # "unsloth/Qwen2.5-0.5B",
     # Skip the -bnb-4bit variants since they're already quantized
 ]
 
@@ -42,30 +42,32 @@
     "special_tokens_map.json",
 ]
 
-@pytest.fixture(scope="session", params=model_to_test)
+
+@pytest.fixture(scope = "session", params = model_to_test)
 def loaded_model_tokenizer(request):
     model_name = request.param
     print("Loading model and tokenizer...")
 
     model, tokenizer = FastModel.from_pretrained(
-        model_name, # use small model
-        max_seq_length=128,
-        dtype=None,
-        load_in_4bit=True,
+        model_name,  # use small model
+        max_seq_length = 128,
+        dtype = None,
+        load_in_4bit = True,
     )
 
     # Apply LoRA
     model = FastModel.get_peft_model(
         model,
-        r=16,
-        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-        lora_alpha=16,
-        use_gradient_checkpointing="unsloth",
+        r = 16,
+        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
+        lora_alpha = 16,
+        use_gradient_checkpointing = "unsloth",
     )
 
     return model, tokenizer
 
-@pytest.fixture(scope="session", params=torchao_models)
+
+@pytest.fixture(scope = "session", params = torchao_models)
 def fp16_model_tokenizer(request):
     """Load model in FP16 for TorchAO quantization"""
     model_name = request.param
@@ -73,32 +75,33 @@ def fp16_model_tokenizer(request):
 
     model, tokenizer = FastModel.from_pretrained(
         model_name,
-        max_seq_length=128,
-        dtype=None,
-        load_in_4bit=False,  # No BnB quantization
+        max_seq_length = 128,
+        dtype = None,
+        load_in_4bit = False,  # No BnB quantization
     )
 
     # Apply LoRA
     model = FastModel.get_peft_model(
         model,
-        r=16,
-        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-        lora_alpha=16,
-        use_gradient_checkpointing="unsloth",
+        r = 16,
+        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
+        lora_alpha = 16,
+        use_gradient_checkpointing = "unsloth",
     )
 
     return model, tokenizer
 
 
-
-@pytest.fixture(scope="session")
+@pytest.fixture(scope = "session")
 def model(loaded_model_tokenizer):
     return loaded_model_tokenizer[0]
 
-@pytest.fixture(scope="session")
+
+@pytest.fixture(scope = "session")
 def tokenizer(loaded_model_tokenizer):
     return loaded_model_tokenizer[1]
 
+
 @pytest.fixture
 def temp_save_dir():
     dir = tempfile.mkdtemp()
@@ -121,32 +124,45 @@ def delete_quantization_config(model):
         original_model.config = new_config
     model.config = new_config
 
+
 def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
-    save_path = os.path.join(temp_save_dir, "unsloth_merged_16bit", model.config._name_or_path.replace("/", "_"))
+    save_path = os.path.join(
+        temp_save_dir,
+        "unsloth_merged_16bit",
+        model.config._name_or_path.replace("/", "_"),
+    )
 
     model.save_pretrained_merged(
-        save_path,
-        tokenizer=tokenizer,
-        save_method="merged_16bit"
+        save_path, tokenizer = tokenizer, save_method = "merged_16bit"
     )
 
     # Check model files
     assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
-    assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found."
-
-    weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
+    assert os.path.isfile(
+        os.path.join(save_path, "config.json")
+    ), "config.json not found."
+
+    weight_files = [
+        f
+        for f in os.listdir(save_path)
+        if f.endswith(".bin") or f.endswith(".safetensors")
+    ]
     assert len(weight_files) > 0, "No weight files found in the save directory."
 
     # Check tokenizer files
     for file in tokenizer_files:
-        assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory."
+        assert os.path.isfile(
+            os.path.join(save_path, file)
+        ), f"{file} not found in the save directory."
 
     # Check config to see if it is 16bit by checking for quantization config
     config_path = os.path.join(save_path, "config.json")
     with open(config_path, "r") as f:
         config = json.load(f)
 
-    assert "quantization_config" not in config, "Quantization config not found in the model config."
+    assert (
+        "quantization_config" not in config
+    ), "Quantization config not found in the model config."
 
     # Store the size of the model files
     total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
@@ -156,30 +172,41 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
     # Test loading the model from the saved path
     loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
         save_path,
-        max_seq_length=128,
-        dtype=None,
-        load_in_4bit=True,
+        max_seq_length = 128,
+        dtype = None,
+        load_in_4bit = True,
     )
 
+
 def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
-    save_path = os.path.join(temp_save_dir, "unsloth_merged_4bit", model.config._name_or_path.replace("/", "_"))
+    save_path = os.path.join(
+        temp_save_dir,
+        "unsloth_merged_4bit",
+        model.config._name_or_path.replace("/", "_"),
+    )
 
     model.save_pretrained_merged(
-        save_path,
-        tokenizer=tokenizer,
-        save_method="merged_4bit_forced"
+        save_path, tokenizer = tokenizer, save_method = "merged_4bit_forced"
     )
 
     # Check model files
     assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
-    assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found."
-
-    weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
+    assert os.path.isfile(
+        os.path.join(save_path, "config.json")
+    ), "config.json not found."
+
+    weight_files = [
+        f
+        for f in os.listdir(save_path)
+        if f.endswith(".bin") or f.endswith(".safetensors")
+    ]
     assert len(weight_files) > 0, "No weight files found in the save directory."
 
     # Check tokenizer files
     for file in tokenizer_files:
-        assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory."
+        assert os.path.isfile(
+            os.path.join(save_path, file)
+        ), f"{file} not found in the save directory."
 
     # Store the size of the model files
     total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
@@ -187,141 +214,187 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
 
     print(f"Total size of merged_4bit files: {total_size} bytes")
 
-    assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], "Merged 4bit files are larger than merged 16bit files."
+    assert (
+        total_size < save_file_sizes["merged_16bit"][model.config._name_or_path]
+    ), "Merged 4bit files are larger than merged 16bit files."
 
     # Check config to see if it is 4bit
     config_path = os.path.join(save_path, "config.json")
     with open(config_path, "r") as f:
         config = json.load(f)
 
-    assert "quantization_config" in config, "Quantization config not found in the model config."
+    assert (
+        "quantization_config" in config
+    ), "Quantization config not found in the model config."
 
     # Test loading the model from the saved path
     loaded_model, loaded_tokenizer = FastModel.from_pretrained(
         save_path,
-        max_seq_length=128,
-        dtype=None,
-        load_in_4bit=True,
+        max_seq_length = 128,
+        dtype = None,
+        load_in_4bit = True,
     )
 
-@pytest.mark.skipif(importlib.util.find_spec("torchao") is None, reason="require torchao to be installed")
+
+@pytest.mark.skipif(
+    importlib.util.find_spec("torchao") is None,
+    reason = "require torchao to be installed",
+)
 def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
     model, tokenizer = fp16_model_tokenizer
-    save_path = os.path.join(temp_save_dir, "unsloth_torchao", model.config._name_or_path.replace("/", "_"))
+    save_path = os.path.join(
+        temp_save_dir, "unsloth_torchao", model.config._name_or_path.replace("/", "_")
+    )
 
     from torchao.quantization import Int8DynamicActivationInt8WeightConfig
+
     torchao_config = Int8DynamicActivationInt8WeightConfig()
     model.save_pretrained_torchao(
         save_path,
-        tokenizer=tokenizer,
-        torchao_config=torchao_config,
-        push_to_hub=False,
+        tokenizer = tokenizer,
+        torchao_config = torchao_config,
+        push_to_hub = False,
     )
 
-    weight_files_16bit = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
-    total_16bit_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files_16bit)
+    weight_files_16bit = [
+        f
+        for f in os.listdir(save_path)
+        if f.endswith(".bin") or f.endswith(".safetensors")
+    ]
+    total_16bit_size = sum(
+        os.path.getsize(os.path.join(save_path, f)) for f in weight_files_16bit
+    )
     save_file_sizes["merged_16bit"][model.config._name_or_path] = total_16bit_size
 
     torchao_save_path = save_path + "-torchao"
 
     # Check model files
-    assert os.path.isdir(torchao_save_path), f"Directory {torchao_save_path} does not exist."
-    assert os.path.isfile(os.path.join(torchao_save_path, "config.json")), "config.json not found."
-
-    weight_files = [f for f in os.listdir(torchao_save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
+    assert os.path.isdir(
+        torchao_save_path
+    ), f"Directory {torchao_save_path} does not exist."
+    assert os.path.isfile(
+        os.path.join(torchao_save_path, "config.json")
+    ), "config.json not found."
+
+    weight_files = [
+        f
+        for f in os.listdir(torchao_save_path)
+        if f.endswith(".bin") or f.endswith(".safetensors")
+    ]
     assert len(weight_files) > 0, "No weight files found in the save directory."
 
     # Check tokenizer files
     for file in tokenizer_files:
-        assert os.path.isfile(os.path.join(torchao_save_path, file)), f"{file} not found in the save directory."
+        assert os.path.isfile(
+            os.path.join(torchao_save_path, file)
+        ), f"{file} not found in the save directory."
 
     # Store the size of the model files
-    total_size = sum(os.path.getsize(os.path.join(torchao_save_path, f)) for f in weight_files)
+    total_size = sum(
+        os.path.getsize(os.path.join(torchao_save_path, f)) for f in weight_files
+    )
     save_file_sizes["torchao"][model.config._name_or_path] = total_size
 
-    assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], "torchao files are larger than merged 16bit files."
+    assert (
+        total_size < save_file_sizes["merged_16bit"][model.config._name_or_path]
+    ), "torchao files are larger than merged 16bit files."
 
     # Check config to see if it is quantized with torchao
     config_path = os.path.join(torchao_save_path, "config.json")
     with open(config_path, "r") as f:
         config = json.load(f)
 
-    assert "quantization_config" in config, "Quantization config not found in the model config."
+    assert (
+        "quantization_config" in config
+    ), "Quantization config not found in the model config."
 
     # Test loading the model from the saved path
     # can't set `load_in_4bit` to True because the model is torchao quantized
     # can't quantize again with bitsandbytes
     import torch.serialization
+
     with torch.serialization.safe_globals([getattr]):
         loaded_model, loaded_tokenizer = FastModel.from_pretrained(
             torchao_save_path,
-            max_seq_length=128,
-            dtype=None,
-            load_in_4bit=False,
+            max_seq_length = 128,
+            dtype = None,
+            load_in_4bit = False,
         )
 
-@pytest.mark.skipif(importlib.util.find_spec("torchao") is None, reason="require torchao to be installed")
+
+@pytest.mark.skipif(
+    importlib.util.find_spec("torchao") is None,
+    reason = "require torchao to be installed",
+)
 def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
     model, tokenizer = fp16_model_tokenizer
     model_name = model.config._name_or_path
 
     print(f"Testing TorchAO save and inference for: {model_name}")
 
-    save_path = os.path.join(temp_save_dir, "torchao_models", model_name.replace("/", "_"))
+    save_path = os.path.join(
+        temp_save_dir, "torchao_models", model_name.replace("/", "_")
+    )
 
     from torchao.quantization import Int8DynamicActivationInt8WeightConfig
+
     torchao_config = Int8DynamicActivationInt8WeightConfig()
 
     # Save with TorchAO
     model.save_pretrained_torchao(
         save_path,
-        tokenizer=tokenizer,
-        torchao_config=torchao_config,
-        push_to_hub=False,
+        tokenizer = tokenizer,
+        torchao_config = torchao_config,
+        push_to_hub = False,
     )
 
     torchao_save_path = save_path + "-torchao"
 
     # Verify files exist
-    assert os.path.isdir(torchao_save_path), f"TorchAO directory {torchao_save_path} does not exist."
+    assert os.path.isdir(
+        torchao_save_path
+    ), f"TorchAO directory {torchao_save_path} does not exist."
 
     # Load with safe globals
     import torch.serialization
+
     with torch.serialization.safe_globals([getattr]):
         loaded_model, loaded_tokenizer = FastModel.from_pretrained(
             torchao_save_path,
-            max_seq_length=128,
-            dtype=None,
-            load_in_4bit=False,
+            max_seq_length = 128,
+            dtype = None,
+            load_in_4bit = False,
         )
 
-
-    FastModel.for_inference(loaded_model) # Enable native 2x faster inference
+    FastModel.for_inference(loaded_model)  # Enable native 2x faster inference
 
     messages = [
-        {"role": "user", "content": "Continue the fibonnaci sequence: 1, 1, 2, 3, 5, 8,"},
+        {
+            "role": "user",
+            "content": "Continue the fibonnaci sequence: 1, 1, 2, 3, 5, 8,",
+        },
     ]
     inputs = loaded_tokenizer.apply_chat_template(
         messages,
         tokenize = True,
-        add_generation_prompt = True, # Must add for generation
+        add_generation_prompt = True,  # Must add for generation
         return_tensors = "pt",
     ).to("cuda")
 
     outputs = loaded_model.generate(  # ā Use loaded_model, not model
-        input_ids=inputs,
-        max_new_tokens=64,
-        use_cache=False,  # Avoid cache issues
-        temperature=1.5,
-        min_p=0.1,
-        do_sample=True,
-        pad_token_id=loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
+        input_ids = inputs,
+        max_new_tokens = 64,
+        use_cache = False,  # Avoid cache issues
+        temperature = 1.5,
+        min_p = 0.1,
+        do_sample = True,
+        pad_token_id = loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
     )
 
-    #Decode with the LOADED tokenizer
-    generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens=True)
-    response_part = generated_text[len(input_text):].strip()
+    # Decode with the LOADED tokenizer
+    generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens = True)
+    input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens = True)
+    response_part = generated_text[len(input_text) :].strip()
 
     print(f"Input: {input_text}")
     print(f"Full output: {generated_text}")
diff --git a/tests/saving/text_to_speech_models/test_csm.py b/tests/saving/text_to_speech_models/test_csm.py
index c3703cacf..5f0694208 100644
--- a/tests/saving/text_to_speech_models/test_csm.py
+++ b/tests/saving/text_to_speech_models/test_csm.py
@@ -1,6 +1,7 @@
 from unsloth import FastLanguageModel, FastModel
 from transformers import CsmForConditionalGeneration
 import torch
+
 # ruff: noqa
 import sys
 from pathlib import Path
@@ -19,17 +20,17 @@
 
 import soundfile as sf
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 model, tokenizer = FastModel.from_pretrained(
     model_name = "unsloth/csm-1b",
-    max_seq_length= 2048, # Choose any for long context!
-    dtype = None, # Leave as None for auto-detection
+    max_seq_length = 2048,  # Choose any for long context!
+    dtype = None,  # Leave as None for auto-detection
     auto_model = CsmForConditionalGeneration,
-    load_in_4bit = False, # Select True for 4bit - reduces memory usage
+    load_in_4bit = False,  # Select True for 4bit - reduces memory usage
 )
 
 
@@ -38,34 +39,41 @@
 
 model = FastModel.get_peft_model(
     model,
-    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
-                      "gate_proj", "up_proj", "down_proj",],
+    r = 32,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+    target_modules = [
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+    ],
     lora_alpha = 32,
-    lora_dropout = 0, # Supports any, but = 0 is optimized
-    bias = "none",    # Supports any, but = "none" is optimized
+    lora_dropout = 0,  # Supports any, but = 0 is optimized
+    bias = "none",  # Supports any, but = "none" is optimized
     # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
-    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
     random_state = 3407,
     use_rslora = False,  # We support rank stabilized LoRA
-    loftq_config = None, # And LoftQ
+    loftq_config = None,  # And LoftQ
 )
 
 print("ā
 Model and LoRA adapters loaded successfully!")
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
 print("ā
 Model is an instance of PeftModel!")
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
+
 
 def find_lora_base_model(model_to_inspect):
     current = model_to_inspect
@@ -74,19 +82,19 @@ def find_lora_base_model(model_to_inspect):
     if hasattr(current, "model"):
         current = current.model
     return current
-pass
 
 
 config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
 
-assert config_model.__class__.__name__ == base_model_class, f"Expected config_model class to be {base_model_class}"
+assert (
+    config_model.__class__.__name__ == base_model_class
+), f"Expected config_model class to be {base_model_class}"
 print("ā
 config_model returns correct Base Model class:", str(base_model_class))
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 with warnings.catch_warnings():
     warnings.simplefilter("error")  # Treat warnings as errors
@@ -96,49 +104,53 @@ def find_lora_base_model(model_to_inspect):
     except Exception as e:
         assert False, f"Model saving/merging failed with exception: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 model, processor = FastModel.from_pretrained(
     model_name = "./csm",
-    max_seq_length= 2048, # Choose any for long context!
-    dtype = None, # Leave as None for auto-detection
+    max_seq_length = 2048,  # Choose any for long context!
+    dtype = None,  # Leave as None for auto-detection
     auto_model = CsmForConditionalGeneration,
-    load_in_4bit = False, # Select True for 4bit - reduces memory usage
+    load_in_4bit = False,  # Select True for 4bit - reduces memory usage
 )
 
 from transformers import AutoProcessor
+
 processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
 
 print("ā
 Model loaded for inference successfully!")
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 6: Running Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 from transformers import pipeline
 import torch
+
 output_audio_path = "csm_audio.wav"
 try:
-    text = "We just finished fine tuning a text to speech model... and it's pretty good!"
+    text = (
+        "We just finished fine tuning a text to speech model... and it's pretty good!"
+    )
     speaker_id = 0
-    inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda")
+    inputs = processor(f"[{speaker_id}]{text}", add_special_tokens = True).to("cuda")
     audio_values = model.generate(
         **inputs,
-        max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
+        max_new_tokens = 125,  # 125 tokens is 10 seconds of audio, for longer speech increase this
         # play with these parameters to get the best results
-        depth_decoder_temperature=0.6,
-        depth_decoder_top_k=0,
-        depth_decoder_top_p=0.9,
-        temperature=0.8,
-        top_k=50,
-        top_p=1.0,
+        depth_decoder_temperature = 0.6,
+        depth_decoder_top_k = 0,
+        depth_decoder_top_p = 0.9,
+        temperature = 0.8,
+        top_k = 50,
+        top_p = 1.0,
         #########################################################
-        output_audio=True
+        output_audio = True,
     )
     audio = audio_values[0].to(torch.float32).cpu().numpy()
     sf.write("example_without_context.wav", audio, 24000)
diff --git a/tests/saving/text_to_speech_models/test_lasa.py b/tests/saving/text_to_speech_models/test_lasa.py
index a4bc5eda8..eadc1bc06 100644
--- a/tests/saving/text_to_speech_models/test_lasa.py
+++ b/tests/saving/text_to_speech_models/test_lasa.py
@@ -1,6 +1,7 @@
 from unsloth import FastLanguageModel, FastModel
 from transformers import CsmForConditionalGeneration
 import torch
+
 # ruff: noqa
 import sys
 from pathlib import Path
@@ -22,6 +23,7 @@
 
 import soundfile as sf
 from xcodec2.modeling_xcodec2 import XCodec2Model
+
 XCODEC2_MODEL_NAME = "HKUST-Audio/xcodec2"
 SAMPLE_RATE = 16000
 DEVICE = "cuda"
@@ -32,18 +34,18 @@
 except Exception as e:
     raise f"ERROR loading XCodec2 model: {e}."
 
-codec_model.to('cpu')
+codec_model.to("cpu")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 max_seq_length = 2048
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name = "unsloth/Llasa-1B",
     max_seq_length = max_seq_length,
-    dtype = None, # Select None for auto detection
-    load_in_4bit = False, # Choose True for 4bit which reduces memory
+    dtype = None,  # Select None for auto detection
+    load_in_4bit = False,  # Choose True for 4bit which reduces memory
     # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
 )
 
@@ -52,33 +54,33 @@
 
 model = FastLanguageModel.get_peft_model(
     model,
-    r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+    r = 128,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
     target_modules = ["q_proj", "v_proj"],
     lora_alpha = 128,
-    lora_dropout = 0, # Supports any, but = 0 is optimized
-    bias = "none",    # Supports any, but = "none" is optimized
+    lora_dropout = 0,  # Supports any, but = 0 is optimized
+    bias = "none",  # Supports any, but = "none" is optimized
     # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
-    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
     random_state = 3407,
     use_rslora = False,  # We support rank stabilized LoRA
-    loftq_config = None, # And LoftQ
+    loftq_config = None,  # And LoftQ
 )
 
 print("ā
 Model and LoRA adapters loaded successfully!")
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
 print("ā
 Model is an instance of PeftModel!")
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
+
 
 def find_lora_base_model(model_to_inspect):
     current = model_to_inspect
@@ -87,19 +89,19 @@ def find_lora_base_model(model_to_inspect):
     if hasattr(current, "model"):
         current = current.model
     return current
-pass
 
 
 config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
 
-assert config_model.__class__.__name__ == base_model_class, f"Expected config_model class to be {base_model_class}"
+assert (
+    config_model.__class__.__name__ == base_model_class
+), f"Expected config_model class to be {base_model_class}"
 print("ā
 config_model returns correct Base Model class:", str(base_model_class))
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 with warnings.catch_warnings():
     warnings.simplefilter("error")  # Treat warnings as errors
@@ -109,49 +111,50 @@ def find_lora_base_model(model_to_inspect):
     except Exception as e:
         assert False, f"Model saving/merging failed with exception: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name = "./lasa",
     max_seq_length = max_seq_length,
-    dtype = None, # Select None for auto detection
-    load_in_4bit = False, # Choose True for 4bit which reduces memory
+    dtype = None,  # Select None for auto detection
+    load_in_4bit = False,  # Choose True for 4bit which reduces memory
     # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
 )
 
-#from transformers import AutoProcessor
-#processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
+# from transformers import AutoProcessor
+# processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
 
 print("ā
 Model loaded for inference successfully!")
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 6: Running Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 from transformers import pipeline
 import torch
+
 output_audio_path = "lasa_audio.wav"
 input_text = "Hey there my name is Elise,  and I'm a speech generation model that can sound like a person."
 
 FastLanguageModel.for_inference(model)
 
-def ids_to_speech_tokens(speech_ids):
 
+def ids_to_speech_tokens(speech_ids):
     speech_tokens_str = []
     for speech_id in speech_ids:
         speech_tokens_str.append(f"<|s_{speech_id}|>")
     return speech_tokens_str
 
-def extract_speech_ids(speech_tokens_str):
 
+def extract_speech_ids(speech_tokens_str):
     speech_ids = []
     for token_str in speech_tokens_str:
-        if token_str.startswith('<|s_') and token_str.endswith('|>'):
+        if token_str.startswith("<|s_") and token_str.endswith("|>"):
             num_str = token_str[4:-2]
 
             num = int(num_str)
@@ -160,40 +163,40 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
-#TTS start!
+
+# TTS start!
 with torch.inference_mode():
-    with torch.amp.autocast('cuda',dtype=model.dtype):
-        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
+    with torch.amp.autocast("cuda", dtype = model.dtype):
+        formatted_text = (
+            f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
+        )
 
         # Tokenize the text
         chat = [
             {"role": "user", "content": "Convert the text to speech:" + formatted_text},
-            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
+            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
         ]
 
         input_ids = tokenizer.apply_chat_template(
-            chat,
-            tokenize=True,
-            return_tensors='pt',
-            continue_final_message=True
+            chat, tokenize = True, return_tensors = "pt", continue_final_message = True
         )
-        input_ids = input_ids.to('cuda')
+        input_ids = input_ids.to("cuda")
 
-        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
+        speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
 
         # Generate the speech autoregressively
         outputs = model.generate(
             input_ids,
-            max_length=2048,  # We trained our model with a max length of 2048
-            eos_token_id= speech_end_id ,
-            do_sample=True,
-            top_p=1.2,           #  Adjusts the diversity of generated content
-            temperature=1.2,   #  Controls randomness in output
+            max_length = 2048,  # We trained our model with a max length of 2048
+            eos_token_id = speech_end_id,
+            do_sample = True,
+            top_p = 1.2,  #  Adjusts the diversity of generated content
+            temperature = 1.2,  #  Controls randomness in output
         )
     # Extract the speech tokens
-    generated_ids = outputs[0][input_ids.shape[1]:-1]
+    generated_ids = outputs[0][input_ids.shape[1] : -1]
 
-    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens = True)
 
     # Convert  token <|s_23456|> to int 23456
     speech_tokens = extract_speech_ids(speech_tokens)
diff --git a/tests/saving/text_to_speech_models/test_orpheus.py b/tests/saving/text_to_speech_models/test_orpheus.py
index 813a80aac..8bd94b29c 100644
--- a/tests/saving/text_to_speech_models/test_orpheus.py
+++ b/tests/saving/text_to_speech_models/test_orpheus.py
@@ -1,6 +1,7 @@
 from unsloth import FastLanguageModel, FastModel
 from transformers import CsmForConditionalGeneration
 import torch
+
 # ruff: noqa
 import sys
 from pathlib import Path
@@ -20,18 +21,19 @@
 
 import soundfile as sf
 from snac import SNAC
+
 snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
 snac_model = snac_model.to("cuda")
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name = "unsloth/orpheus-3b-0.1-ft",
-    max_seq_length= 2048, # Choose any for long context!
-    dtype = None, # Select None for auto detection
-    load_in_4bit = False, # Select True for 4bit which reduces memory usage
+    max_seq_length = 2048,  # Choose any for long context!
+    dtype = None,  # Select None for auto detection
+    load_in_4bit = False,  # Select True for 4bit which reduces memory usage
     # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
 )
 
@@ -40,33 +42,40 @@
 
 model = FastLanguageModel.get_peft_model(
     model,
-    r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
-                      "gate_proj", "up_proj", "down_proj",],
+    r = 64,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+    target_modules = [
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+    ],
     lora_alpha = 64,
-    lora_dropout = 0, # Supports any, but = 0 is optimized
-    bias = "none",    # Supports any, but = "none" is optimized
+    lora_dropout = 0,  # Supports any, but = 0 is optimized
+    bias = "none",  # Supports any, but = "none" is optimized
     # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
-    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
     random_state = 3407,
     use_rslora = False,  # We support rank stabilized LoRA
-    loftq_config = None, # And LoftQ
+    loftq_config = None,  # And LoftQ
 )
 print("ā
 Model and LoRA adapters loaded successfully!")
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
 print("ā
 Model is an instance of PeftModel!")
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
+
 
 def find_lora_base_model(model_to_inspect):
     current = model_to_inspect
@@ -75,19 +84,19 @@ def find_lora_base_model(model_to_inspect):
     if hasattr(current, "model"):
         current = current.model
     return current
-pass
 
 
 config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
 
-assert config_model.__class__.__name__ == base_model_class, f"Expected config_model class to be {base_model_class}"
+assert (
+    config_model.__class__.__name__ == base_model_class
+), f"Expected config_model class to be {base_model_class}"
 print("ā
 config_model returns correct Base Model class:", str(base_model_class))
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 with warnings.catch_warnings():
     warnings.simplefilter("error")  # Treat warnings as errors
@@ -97,34 +106,34 @@ def find_lora_base_model(model_to_inspect):
     except Exception as e:
         assert False, f"Model saving/merging failed with exception: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name = "unsloth/orpheus-3b-0.1-ft",
-    max_seq_length= 2048, # Choose any for long context!
-    dtype = None, # Select None for auto detection
-    load_in_4bit = False, # Select True for 4bit which reduces memory usage
+    max_seq_length = 2048,  # Choose any for long context!
+    dtype = None,  # Select None for auto detection
+    load_in_4bit = False,  # Select True for 4bit which reduces memory usage
     # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
 )
 
-#from transformers import AutoProcessor
-#processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
+# from transformers import AutoProcessor
+# processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
 
 print("ā
 Model loaded for inference successfully!")
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 6: Running Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
-#@title Run Inference
+# @title Run Inference
 
 
-FastLanguageModel.for_inference(model) # Enable native 2x faster inference
+FastLanguageModel.for_inference(model)  # Enable native 2x faster inference
 
 # Moving snac_model cuda to cpu
 snac_model.to("cpu")
@@ -132,59 +141,73 @@ def find_lora_base_model(model_to_inspect):
     "Hey there my name is Elise,  and I'm a speech generation model that can sound like a person.",
 ]
 
-chosen_voice = None # None for single-speaker
+chosen_voice = None  # None for single-speaker
 
 prompts_ = [(f"{chosen_voice}: " + p) if chosen_voice else p for p in prompts]
 
 all_input_ids = []
 
 for prompt in prompts_:
-  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-  all_input_ids.append(input_ids)
+    input_ids = tokenizer(prompt, return_tensors = "pt").input_ids
+    all_input_ids.append(input_ids)
 
-start_token = torch.tensor([[ 128259]], dtype=torch.int64) # Start of human
-end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End of text, End of human
+start_token = torch.tensor([[128259]], dtype = torch.int64)  # Start of human
+end_tokens = torch.tensor(
+    [[128009, 128260]], dtype = torch.int64
+)  # End of text, End of human
 
 all_modified_input_ids = []
 for input_ids in all_input_ids:
-  modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1) # SOH SOT Text EOT EOH
-  all_modified_input_ids.append(modified_input_ids)
+    modified_input_ids = torch.cat(
+        [start_token, input_ids, end_tokens], dim = 1
+    )  # SOH SOT Text EOT EOH
+    all_modified_input_ids.append(modified_input_ids)
 
 all_padded_tensors = []
 all_attention_masks = []
-max_length = max([modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids])
+max_length = max(
+    [modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids]
+)
 for modified_input_ids in all_modified_input_ids:
-  padding = max_length - modified_input_ids.shape[1]
-  padded_tensor = torch.cat([torch.full((1, padding), 128263, dtype=torch.int64), modified_input_ids], dim=1)
-  attention_mask = torch.cat([torch.zeros((1, padding), dtype=torch.int64), torch.ones((1, modified_input_ids.shape[1]), dtype=torch.int64)], dim=1)
-  all_padded_tensors.append(padded_tensor)
-  all_attention_masks.append(attention_mask)
-
-all_padded_tensors = torch.cat(all_padded_tensors, dim=0)
-all_attention_masks = torch.cat(all_attention_masks, dim=0)
+    padding = max_length - modified_input_ids.shape[1]
+    padded_tensor = torch.cat(
+        [torch.full((1, padding), 128263, dtype = torch.int64), modified_input_ids], dim = 1
+    )
+    attention_mask = torch.cat(
+        [
+            torch.zeros((1, padding), dtype = torch.int64),
+            torch.ones((1, modified_input_ids.shape[1]), dtype = torch.int64),
+        ],
+        dim = 1,
+    )
+    all_padded_tensors.append(padded_tensor)
+    all_attention_masks.append(attention_mask)
+
+all_padded_tensors = torch.cat(all_padded_tensors, dim = 0)
+all_attention_masks = torch.cat(all_attention_masks, dim = 0)
 
 input_ids = all_padded_tensors.to("cuda")
 attention_mask = all_attention_masks.to("cuda")
 generated_ids = model.generate(
-      input_ids=input_ids,
-      attention_mask=attention_mask,
-      max_new_tokens=1200,
-      do_sample=True,
-      temperature=0.6,
-      top_p=0.95,
-      repetition_penalty=1.1,
-      num_return_sequences=1,
-      eos_token_id=128258,
-     use_cache = True
-  )
+    input_ids = input_ids,
+    attention_mask = attention_mask,
+    max_new_tokens = 1200,
+    do_sample = True,
+    temperature = 0.6,
+    top_p = 0.95,
+    repetition_penalty = 1.1,
+    num_return_sequences = 1,
+    eos_token_id = 128258,
+    use_cache = True,
+)
 token_to_find = 128257
 token_to_remove = 128258
 
-token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
+token_indices = (generated_ids == token_to_find).nonzero(as_tuple = True)
 
 if len(token_indices[1]) > 0:
     last_occurrence_idx = token_indices[1][-1].item()
-    cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+    cropped_tensor = generated_ids[:, last_occurrence_idx + 1 :]
 else:
     cropped_tensor = generated_ids
 
@@ -207,34 +230,38 @@ def find_lora_base_model(model_to_inspect):
 
 
 def redistribute_codes(code_list):
-  layer_1 = []
-  layer_2 = []
-  layer_3 = []
-  for i in range((len(code_list)+1)//7):
-    layer_1.append(code_list[7*i])
-    layer_2.append(code_list[7*i+1]-4096)
-    layer_3.append(code_list[7*i+2]-(2*4096))
-    layer_3.append(code_list[7*i+3]-(3*4096))
-    layer_2.append(code_list[7*i+4]-(4*4096))
-    layer_3.append(code_list[7*i+5]-(5*4096))
-    layer_3.append(code_list[7*i+6]-(6*4096))
-  codes = [torch.tensor(layer_1).unsqueeze(0),
-         torch.tensor(layer_2).unsqueeze(0),
-         torch.tensor(layer_3).unsqueeze(0)]
-
-  # codes = [c.to("cuda") for c in codes]
-  audio_hat = snac_model.decode(codes)
-  return audio_hat
+    layer_1 = []
+    layer_2 = []
+    layer_3 = []
+    for i in range((len(code_list) + 1) // 7):
+        layer_1.append(code_list[7 * i])
+        layer_2.append(code_list[7 * i + 1] - 4096)
+        layer_3.append(code_list[7 * i + 2] - (2 * 4096))
+        layer_3.append(code_list[7 * i + 3] - (3 * 4096))
+        layer_2.append(code_list[7 * i + 4] - (4 * 4096))
+        layer_3.append(code_list[7 * i + 5] - (5 * 4096))
+        layer_3.append(code_list[7 * i + 6] - (6 * 4096))
+    codes = [
+        torch.tensor(layer_1).unsqueeze(0),
+        torch.tensor(layer_2).unsqueeze(0),
+        torch.tensor(layer_3).unsqueeze(0),
+    ]
+
+    # codes = [c.to("cuda") for c in codes]
+    audio_hat = snac_model.decode(codes)
+    return audio_hat
+
 
 my_samples = []
 for code_list in code_lists:
-  samples = redistribute_codes(code_list)
-  my_samples.append(samples)
+    samples = redistribute_codes(code_list)
+    my_samples.append(samples)
 output_path = "orpheus_audio.wav"
 try:
     for i, samples in enumerate(my_samples):
         audio_data = samples.detach().squeeze().cpu().numpy()
         import soundfile as sf
+
         sf.write(output_path, audio_data, 24000)  # Explicitly pass sample rate
         print(f"ā
 Audio saved to {output_path}!")
 except Exception as e:
@@ -242,6 +269,7 @@ def redistribute_codes(code_list):
 
 # Verify the file exists
 import os
+
 assert os.path.exists(output_path), f"Audio file not found at {output_path}"
 print("ā
 Audio file exists on disk!")
 del my_samples, samples
diff --git a/tests/saving/text_to_speech_models/test_whisper.py b/tests/saving/text_to_speech_models/test_whisper.py
index f29213c3a..bb443db42 100644
--- a/tests/saving/text_to_speech_models/test_whisper.py
+++ b/tests/saving/text_to_speech_models/test_whisper.py
@@ -1,6 +1,7 @@
 from unsloth import FastLanguageModel, FastModel
 from transformers import WhisperForConditionalGeneration, WhisperProcessor
 import torch
+
 # ruff: noqa
 import sys
 from pathlib import Path
@@ -21,15 +22,15 @@
 
 import soundfile as sf
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 model, tokenizer = FastModel.from_pretrained(
     model_name = "unsloth/whisper-large-v3",
-    dtype = None, # Leave as None for auto detection
-    load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
+    dtype = None,  # Leave as None for auto detection
+    load_in_4bit = False,  # Set to True to do 4bit quantization which reduces memory
     auto_model = WhisperForConditionalGeneration,
     whisper_language = "English",
     whisper_task = "transcribe",
@@ -38,41 +39,41 @@
 
 
 base_model_class = model.__class__.__name__
-#https://github.com/huggingface/transformers/issues/37172
+# https://github.com/huggingface/transformers/issues/37172
 model.generation_config.input_ids = model.generation_config.forced_decoder_ids
 model.generation_config.forced_decoder_ids = None
 
 
 model = FastModel.get_peft_model(
     model,
-    r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+    r = 64,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
     target_modules = ["q_proj", "v_proj"],
     lora_alpha = 64,
-    lora_dropout = 0, # Supports any, but = 0 is optimized
-    bias = "none",    # Supports any, but = "none" is optimized
+    lora_dropout = 0,  # Supports any, but = 0 is optimized
+    bias = "none",  # Supports any, but = "none" is optimized
     # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
-    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
     random_state = 3407,
     use_rslora = False,  # We support rank stabilized LoRA
-    loftq_config = None, # And LoftQ
-    task_type = None, # ** MUST set this for Whisper **
+    loftq_config = None,  # And LoftQ
+    task_type = None,  # ** MUST set this for Whisper **
 )
 
 print("ā
 Model and LoRA adapters loaded successfully!")
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
 print("ā
 Model is an instance of PeftModel!")
 
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
+
 
 def find_lora_base_model(model_to_inspect):
     current = model_to_inspect
@@ -81,19 +82,19 @@ def find_lora_base_model(model_to_inspect):
     if hasattr(current, "model"):
         current = current.model
     return current
-pass
 
 
 config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
 
-assert config_model.__class__.__name__ == base_model_class, f"Expected config_model class to be {base_model_class}"
+assert (
+    config_model.__class__.__name__ == base_model_class
+), f"Expected config_model class to be {base_model_class}"
 print("ā
 config_model returns correct Base Model class:", str(base_model_class))
 
 
-
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 with warnings.catch_warnings():
     warnings.simplefilter("error")  # Treat warnings as errors
@@ -103,15 +104,15 @@ def find_lora_base_model(model_to_inspect):
     except Exception as e:
         assert False, f"Model saving/merging failed with exception: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 model, tokenizer = FastModel.from_pretrained(
     model_name = "./whisper",
-    dtype = None, # Leave as None for auto detection
-    load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
+    dtype = None,  # Leave as None for auto detection
+    load_in_4bit = False,  # Set to True to do 4bit quantization which reduces memory
     auto_model = WhisperForConditionalGeneration,
     whisper_language = "English",
     whisper_task = "transcribe",
@@ -123,9 +124,9 @@ def find_lora_base_model(model_to_inspect):
 
 print("ā
 Model loaded for inference successfully!")
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 6: Downloading Sample Audio File")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 audio_url = "https://upload.wikimedia.org/wikipedia/commons/5/5b/Speech_12dB_s16.flac"
 audio_file = "Speech_12dB_s16.flac"
@@ -134,7 +135,7 @@ def find_lora_base_model(model_to_inspect):
     headers = {
         "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
     }
-    response = requests.get(audio_url, headers=headers)
+    response = requests.get(audio_url, headers = headers)
     response.raise_for_status()
     with open(audio_file, "wb") as f:
         f.write(response.content)
@@ -142,24 +143,25 @@ def find_lora_base_model(model_to_inspect):
 except Exception as e:
     assert False, f"Failed to download audio file: {e}"
 
-print(f"\n{'='*80}")
+print(f"\n{'=' * 80}")
 print("š SECTION 7: Running Inference")
-print(f"{'='*80}")
+print(f"{'=' * 80}")
 
 
 from transformers import pipeline
 import torch
+
 FastModel.for_inference(model)
 model.eval()
-#Create pipeline without specifying the device
+# Create pipeline without specifying the device
 whisper = pipeline(
     "automatic-speech-recognition",
-    model=model,
-    tokenizer=tokenizer.tokenizer,
-    feature_extractor=tokenizer.feature_extractor,
-    processor=tokenizer,
-    return_language=True,
-    torch_dtype=torch.float16  # Remove the device parameter
+    model = model,
+    tokenizer = tokenizer.tokenizer,
+    feature_extractor = tokenizer.feature_extractor,
+    processor = tokenizer,
+    return_language = True,
+    torch_dtype = torch.float16,  # Remove the device parameter
 )
 # Example usage
 audio_file = "Speech_12dB_s16.flac"
@@ -179,9 +181,13 @@ def find_lora_base_model(model_to_inspect):
 ]
 
 transcribed_lower = transcribed_text["text"].lower()
-all_phrases_found = all(phrase.lower() in transcribed_lower for phrase in expected_phrases)
+all_phrases_found = all(
+    phrase.lower() in transcribed_lower for phrase in expected_phrases
+)
 
-assert all_phrases_found, f"Expected phrases not found in transcription: {transcribed_text['text']}"
+assert (
+    all_phrases_found
+), f"Expected phrases not found in transcription: {transcribed_text['text']}"
 print("ā
 Transcription contains all expected phrases!")
 
 
diff --git a/tests/saving/vision_models/test_index_file_sharded_model.py b/tests/saving/vision_models/test_index_file_sharded_model.py
index 09c4bb19e..f73716984 100644
--- a/tests/saving/vision_models/test_index_file_sharded_model.py
+++ b/tests/saving/vision_models/test_index_file_sharded_model.py
@@ -21,7 +21,7 @@
 ## Dataset Preparation"""
 
 print("\nš Loading and preparing dataset...")
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
+dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
 # To select the first 2000 examples
 train_dataset = dataset.select(range(2000))
 
@@ -31,6 +31,8 @@
 print(f"ā
 Dataset loaded successfully!")
 print(f"   š Training samples: {len(train_dataset)}")
 print(f"   š Evaluation samples: {len(eval_dataset)}")
+
+
 # Convert dataset to OAI messages
 def format_data(sample):
     return {
@@ -59,6 +61,7 @@ def format_data(sample):
         ],
     }
 
+
 print("\nš Formatting dataset for vision training...")
 system_message = "You are an expert french ocr system."
 # Convert dataset to OAI messages
@@ -78,11 +81,11 @@ def format_data(sample):
 try:
     model, tokenizer = FastVisionModel.from_pretrained(
         # model_name = "unsloth/Qwen2-VL-7B-Instruct",
-        model_name="unsloth/Qwen2-VL-7B-Instruct",
-        max_seq_length=2048,  # Choose any for long context!
-        load_in_4bit=True,  # 4 bit quantization to reduce memory
-        load_in_8bit=False,  # [NEW!] A bit more accurate, uses 2x memory
-        full_finetuning=False,  # [NEW!] We have full finetuning now!
+        model_name = "unsloth/Qwen2-VL-7B-Instruct",
+        max_seq_length = 2048,  # Choose any for long context!
+        load_in_4bit = True,  # 4 bit quantization to reduce memory
+        load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory
+        full_finetuning = False,  # [NEW!] We have full finetuning now!
     )
 except Exception as e:
     print(f"ā Failed to load base model: {e}")
@@ -93,18 +96,18 @@ def format_data(sample):
 try:
     model = FastVisionModel.get_peft_model(
         model,
-        finetune_vision_layers=True,  # Turn off for just text!
-        finetune_language_layers=True,  # Should leave on!
-        finetune_attention_modules=True,  # Attention good for GRPO
-        finetune_mlp_modules=True,  # SHould leave on always!
-        r=16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-        lora_alpha=32,
-        lora_dropout=0,  # Supports any, but = 0 is optimized
-        bias="none",  # Supports any, but = "none" is optimized
-        use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
-        random_state=3407,
-        use_rslora=False,  # We support rank stabilized LoRA
-        loftq_config=None,  # And LoftQ
+        finetune_vision_layers = True,  # Turn off for just text!
+        finetune_language_layers = True,  # Should leave on!
+        finetune_attention_modules = True,  # Attention good for GRPO
+        finetune_mlp_modules = True,  # SHould leave on always!
+        r = 16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+        lora_alpha = 32,
+        lora_dropout = 0,  # Supports any, but = 0 is optimized
+        bias = "none",  # Supports any, but = "none" is optimized
+        use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
+        random_state = 3407,
+        use_rslora = False,  # We support rank stabilized LoRA
+        loftq_config = None,  # And LoftQ
     )
     print("ā
 LoRA configuration applied successfully!")
     print(f"   šÆ LoRA rank (r): 16")
@@ -125,40 +128,40 @@ def format_data(sample):
 
 try:
     trainer = SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        data_collator=UnslothVisionDataCollator(model, tokenizer),
-        train_dataset=train_dataset,
-        args=SFTConfig(
+        model = model,
+        tokenizer = tokenizer,
+        data_collator = UnslothVisionDataCollator(model, tokenizer),
+        train_dataset = train_dataset,
+        args = SFTConfig(
             # per_device_train_batch_size = 4,
             # gradient_accumulation_steps = 8,
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            gradient_checkpointing=True,
-            gradient_checkpointing_kwargs={
+            per_device_train_batch_size = 2,
+            gradient_accumulation_steps = 4,
+            gradient_checkpointing = True,
+            gradient_checkpointing_kwargs = {
                 "use_reentrant": False
             },  # use reentrant checkpointing
-            max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
-            warmup_ratio=0.03,
+            max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper
+            warmup_ratio = 0.03,
             # num_train_epochs = 2, # Set this instead of max_steps for full training runs
-            max_steps=10,
-            learning_rate=2e-4,
-            fp16=not is_bf16_supported(),
-            bf16=is_bf16_supported(),
-            logging_steps=5,
-            save_strategy="epoch",
-            optim="adamw_torch_fused",
-            weight_decay=0.01,
-            lr_scheduler_type="linear",
-            seed=3407,
-            output_dir="checkpoints",
-            report_to="none",  # For Weights and Biases
+            max_steps = 10,
+            learning_rate = 2e-4,
+            fp16 = not is_bf16_supported(),
+            bf16 = is_bf16_supported(),
+            logging_steps = 5,
+            save_strategy = "epoch",
+            optim = "adamw_torch_fused",
+            weight_decay = 0.01,
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = "checkpoints",
+            report_to = "none",  # For Weights and Biases
             # You MUST put the below items for vision finetuning:
-            remove_unused_columns=False,
-            dataset_text_field="",
-            dataset_kwargs={"skip_prepare_dataset": True},
-            dataset_num_proc=4,
-            max_seq_length=2048,
+            remove_unused_columns = False,
+            dataset_text_field = "",
+            dataset_kwargs = {"skip_prepare_dataset": True},
+            dataset_num_proc = 4,
+            max_seq_length = 2048,
         ),
     )
     print("ā
 Trainer setup completed!")
@@ -218,7 +221,7 @@ def format_data(sample):
     print("=== UPLOADING MODEL TO HUB ===".center(80))
     print("=" * 80 + "\n")
     print(f"š Uploading to repository: {repo_name}")
-    model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
     success["upload"] = True
     print("ā
 Model uploaded successfully!")
 except Exception as e:
@@ -230,8 +233,8 @@ def format_data(sample):
     print("\n" + "=" * 80)
     print("=== VERIFYING REPO CONTENTS ===".center(80))
     print("=" * 80 + "\n")
-    fs = HfFileSystem(token=hf_token)
-    file_list = fs.ls(repo_name, detail=True)
+    fs = HfFileSystem(token = hf_token)
+    file_list = fs.ls(repo_name, detail = True)
     safetensors_found = any(
         file["name"].endswith("model.safetensors.index.json") for file in file_list
     )
diff --git a/tests/saving/vision_models/test_push_to_hub_merged.py b/tests/saving/vision_models/test_push_to_hub_merged.py
index 372d22573..74fa05898 100644
--- a/tests/saving/vision_models/test_push_to_hub_merged.py
+++ b/tests/saving/vision_models/test_push_to_hub_merged.py
@@ -22,7 +22,7 @@
 ## Dataset Preparation"""
 
 print("\nš Loading and preparing dataset...")
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
+dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
 # To select the first 2000 examples
 train_dataset = dataset.select(range(2000))
 
@@ -32,6 +32,8 @@
 print(f"ā
 Dataset loaded successfully!")
 print(f"   š Training samples: {len(train_dataset)}")
 print(f"   š Evaluation samples: {len(eval_dataset)}")
+
+
 # Convert dataset to OAI messages
 def format_data(sample):
     return {
@@ -60,6 +62,7 @@ def format_data(sample):
         ],
     }
 
+
 print("\nš Formatting dataset for vision training...")
 system_message = "You are an expert french ocr system."
 # Convert dataset to OAI messages
@@ -79,11 +82,11 @@ def format_data(sample):
 try:
     model, tokenizer = FastVisionModel.from_pretrained(
         # model_name = "unsloth/Qwen2-VL-7B-Instruct",
-        model_name="unsloth/Qwen2-VL-2B-Instruct",
-        max_seq_length=2048,  # Choose any for long context!
-        load_in_4bit=True,  # 4 bit quantization to reduce memory
-        load_in_8bit=False,  # [NEW!] A bit more accurate, uses 2x memory
-        full_finetuning=False,  # [NEW!] We have full finetuning now!
+        model_name = "unsloth/Qwen2-VL-2B-Instruct",
+        max_seq_length = 2048,  # Choose any for long context!
+        load_in_4bit = True,  # 4 bit quantization to reduce memory
+        load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory
+        full_finetuning = False,  # [NEW!] We have full finetuning now!
     )
 except Exception as e:
     print(f"ā Failed to load base model: {e}")
@@ -94,18 +97,18 @@ def format_data(sample):
 try:
     model = FastVisionModel.get_peft_model(
         model,
-        finetune_vision_layers=True,  # Turn off for just text!
-        finetune_language_layers=True,  # Should leave on!
-        finetune_attention_modules=True,  # Attention good for GRPO
-        finetune_mlp_modules=True,  # SHould leave on always!
-        r=16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-        lora_alpha=32,
-        lora_dropout=0,  # Supports any, but = 0 is optimized
-        bias="none",  # Supports any, but = "none" is optimized
-        use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
-        random_state=3407,
-        use_rslora=False,  # We support rank stabilized LoRA
-        loftq_config=None,  # And LoftQ
+        finetune_vision_layers = True,  # Turn off for just text!
+        finetune_language_layers = True,  # Should leave on!
+        finetune_attention_modules = True,  # Attention good for GRPO
+        finetune_mlp_modules = True,  # SHould leave on always!
+        r = 16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+        lora_alpha = 32,
+        lora_dropout = 0,  # Supports any, but = 0 is optimized
+        bias = "none",  # Supports any, but = "none" is optimized
+        use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
+        random_state = 3407,
+        use_rslora = False,  # We support rank stabilized LoRA
+        loftq_config = None,  # And LoftQ
     )
     print("ā
 LoRA configuration applied successfully!")
     print(f"   šÆ LoRA rank (r): 16")
@@ -126,40 +129,40 @@ def format_data(sample):
 
 try:
     trainer = SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        data_collator=UnslothVisionDataCollator(model, tokenizer),
-        train_dataset=train_dataset,
-        args=SFTConfig(
+        model = model,
+        tokenizer = tokenizer,
+        data_collator = UnslothVisionDataCollator(model, tokenizer),
+        train_dataset = train_dataset,
+        args = SFTConfig(
             # per_device_train_batch_size = 4,
             # gradient_accumulation_steps = 8,
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            gradient_checkpointing=True,
-            gradient_checkpointing_kwargs={
+            per_device_train_batch_size = 2,
+            gradient_accumulation_steps = 4,
+            gradient_checkpointing = True,
+            gradient_checkpointing_kwargs = {
                 "use_reentrant": False
             },  # use reentrant checkpointing
-            max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
-            warmup_ratio=0.03,
+            max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper
+            warmup_ratio = 0.03,
             # num_train_epochs = 2, # Set this instead of max_steps for full training runs
-            max_steps=10,
-            learning_rate=2e-4,
-            fp16=not is_bf16_supported(),
-            bf16=is_bf16_supported(),
-            logging_steps=5,
-            save_strategy="epoch",
-            optim="adamw_torch_fused",
-            weight_decay=0.01,
-            lr_scheduler_type="linear",
-            seed=3407,
-            output_dir="checkpoints",
-            report_to="none",  # For Weights and Biases
+            max_steps = 10,
+            learning_rate = 2e-4,
+            fp16 = not is_bf16_supported(),
+            bf16 = is_bf16_supported(),
+            logging_steps = 5,
+            save_strategy = "epoch",
+            optim = "adamw_torch_fused",
+            weight_decay = 0.01,
+            lr_scheduler_type = "linear",
+            seed = 3407,
+            output_dir = "checkpoints",
+            report_to = "none",  # For Weights and Biases
             # You MUST put the below items for vision finetuning:
-            remove_unused_columns=False,
-            dataset_text_field="",
-            dataset_kwargs={"skip_prepare_dataset": True},
-            dataset_num_proc=4,
-            max_seq_length=2048,
+            remove_unused_columns = False,
+            dataset_text_field = "",
+            dataset_kwargs = {"skip_prepare_dataset": True},
+            dataset_num_proc = 4,
+            max_seq_length = 2048,
         ),
     )
     print("ā
 Trainer setup completed!")
@@ -218,7 +221,7 @@ def format_data(sample):
     print("=== UPLOADING MODEL TO HUB ===".center(80))
     print("=" * 80 + "\n")
     print(f"š Uploading to repository: {repo_name}")
-    model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
+    model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
     success["upload"] = True
     print("ā
 Model uploaded successfully!")
 except Exception as e:
diff --git a/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py b/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py
index 0bf548b41..ebe078c73 100644
--- a/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py
+++ b/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py
@@ -22,38 +22,42 @@
 ## Dataset Preparation
 from datasets import load_dataset
 
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", 'en', split="train")
+dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
 # To select the first 2000 examples
 train_dataset = dataset.select(range(2000))
 
 # To select the next 200 examples for evaluation
 eval_dataset = dataset.select(range(2000, 2200))
 
+
 # Convert dataset to OAI messages
 def format_data(sample):
-    return {"messages": [
-                {
-                    "role": "system",
-                    "content": [{"type": "text", "text": system_message}],
-                },
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": sample["question"],
-                        },{
-                            "type": "image",
-                            "image": sample["image"],
-                        }
-                    ],
-                },
-                {
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": sample["answer"]}],
-                },
-            ],
-        }
+    return {
+        "messages": [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": system_message}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": sample["question"],
+                    },
+                    {
+                        "type": "image",
+                        "image": sample["image"],
+                    },
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": sample["answer"]}],
+            },
+        ],
+    }
+
 
 system_message = "You are an expert french ocr system."
 # Convert dataset to OAI messages
@@ -78,42 +82,44 @@ def format_data(sample):
 
 model, tokenizer = FastVisionModel.from_pretrained(
     model_name = "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
-    max_seq_length = 2048, # Choose any for long context!
+    max_seq_length = 2048,  # Choose any for long context!
     load_in_4bit = True,  # 4 bit quantization to reduce memory
-    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
-    full_finetuning = False, # [NEW!] We have full finetuning now!
+    load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory
+    full_finetuning = False,  # [NEW!] We have full finetuning now!
 )
 
 # benchmark base model performance
 model_name = "Unsloth Base model"
 FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 ## Lora Finetuning
 model = FastVisionModel.get_peft_model(
     model,
-    finetune_vision_layers     = True, # Turn off for just text!
-    finetune_language_layers   = True,  # Should leave on!
+    finetune_vision_layers = True,  # Turn off for just text!
+    finetune_language_layers = True,  # Should leave on!
     finetune_attention_modules = True,  # Attention good for GRPO
-    finetune_mlp_modules       = True,  # SHould leave on always!
-
-    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-    #target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
-                      #"gate_proj", "up_proj", "down_proj",],
+    finetune_mlp_modules = True,  # SHould leave on always!
+    r = 16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+    # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+    # "gate_proj", "up_proj", "down_proj",],
     lora_alpha = 32,
-    lora_dropout = 0, # Supports any, but = 0 is optimized
-    bias = "none",    # Supports any, but = "none" is optimized
+    lora_dropout = 0,  # Supports any, but = 0 is optimized
+    bias = "none",  # Supports any, but = "none" is optimized
     # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
-    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
     random_state = 3407,
     use_rslora = False,  # We support rank stabilized LoRA
-    loftq_config = None, # And LoftQ
+    loftq_config = None,  # And LoftQ
 )
 
 from unsloth import is_bf16_supported
 from unsloth.trainer import UnslothVisionDataCollator
-FastVisionModel.for_training(model) # Enable for training!
+
+FastVisionModel.for_training(model)  # Enable for training!
 model.config.use_cache = False
 
 
@@ -123,28 +129,29 @@ def format_data(sample):
     data_collator = UnslothVisionDataCollator(model, tokenizer),
     train_dataset = train_dataset,
     args = SFTConfig(
-        #per_device_train_batch_size = 4,
-        #gradient_accumulation_steps = 8,
+        # per_device_train_batch_size = 4,
+        # gradient_accumulation_steps = 8,
         per_device_train_batch_size = 2,
         gradient_accumulation_steps = 4,
-        gradient_checkpointing=True,
-        gradient_checkpointing_kwargs = {"use_reentrant": False}, # use reentrant checkpointing
-        max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
-        warmup_ratio=0.03,
-        #num_train_epochs = 2, # Set this instead of max_steps for full training runs
-        max_steps=60,
+        gradient_checkpointing = True,
+        gradient_checkpointing_kwargs = {
+            "use_reentrant": False
+        },  # use reentrant checkpointing
+        max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper
+        warmup_ratio = 0.03,
+        # num_train_epochs = 2, # Set this instead of max_steps for full training runs
+        max_steps = 60,
         learning_rate = 2e-4,
         fp16 = not is_bf16_supported(),
         bf16 = is_bf16_supported(),
         logging_steps = 5,
-        save_strategy="epoch",
+        save_strategy = "epoch",
         optim = "adamw_torch_fused",
         weight_decay = 0.01,
         lr_scheduler_type = "linear",
         seed = 3407,
         output_dir = "unsloth-qwen2.5-vl-32b-french-ocr-checkpoints",
-        report_to = "none",     # For Weights and Biases
-
+        report_to = "none",  # For Weights and Biases
         # You MUST put the below items for vision finetuning:
         remove_unused_columns = False,
         dataset_text_field = "",
@@ -165,11 +172,14 @@ def format_data(sample):
 # benchmark lora model performance
 model_name = "Unsloth lora adapter model"
 FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 ## Merge Model
 
+
 def find_lora_base_model(model_to_inspect):
     current = model_to_inspect
     if hasattr(current, "base_model"):
@@ -177,45 +187,68 @@ def find_lora_base_model(model_to_inspect):
     if hasattr(current, "model"):
         current = current.model
     return current
-pass
+
 
 base = find_lora_base_model(model)
 
 print((base.__class__.__name__))
 
 # merge default 16 bits
-model.save_pretrained_merged(save_directory="qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer)
+model.save_pretrained_merged(
+    save_directory = "qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
+)
 
 
 ## Benchmark merged model performance
 
 ### 16 bits merged model
 
-model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit",load_in_4bit=False, load_in_8bit=False)
+model, tokenizer = FastVisionModel.from_pretrained(
+    "./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
+)
 
 # benchmark 4bit loaded, 16bits merged model performance
 model_name = "Unsloth 16bits-merged model load-16bits"
 model.config.use_cache = True
 
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_16bits_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model,
+    tokenizer,
+    eval_dataset,
+    output_dir = "unsloth_16bits_merged_model_load_16bits_results",
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # load 16bits-merged model in 4 bits
-model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit",load_in_4bit=True, load_in_8bit=False)
+model, tokenizer = FastVisionModel.from_pretrained(
+    "./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
+)
 
 # benchmark 4bit loaded, 16bits merged model performance
 model_name = "Unsloth 16bits-merged model load-4bits"
 model.config.use_cache = True
 
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_4bits_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model,
+    tokenizer,
+    eval_dataset,
+    output_dir = "unsloth_16bits_merged_model_load_4bits_results",
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # load model in 8 bits
-model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit",load_in_4bit=False, load_in_8bit=True)
+model, tokenizer = FastVisionModel.from_pretrained(
+    "./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
+)
 
 # benchmark 4bit loaded, 16bits merged model performance
 model_name = "Unsloth 16bits-merged model load-8bits"
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_8bits_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model,
+    tokenizer,
+    eval_dataset,
+    output_dir = "unsloth_16bits_merged_model_load_8bits_results",
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # """### 4 bits merged model"""
@@ -239,11 +272,10 @@ def find_lora_base_model(model_to_inspect):
 # ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # Model comparison report
-#print model comparison
+# print model comparison
 ocr_evaluator.print_model_comparison()
 
 
-
 # Final cleanup
 print("\nš§¹ Cleaning up temporary files...")
 safe_remove_directory("./unsloth-qwen2.5-vl-32b-french-ocr-adapter")
diff --git a/tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py b/tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py
index e556f6246..b99785bcb 100644
--- a/tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py
+++ b/tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py
@@ -22,38 +22,42 @@
 ## Dataset Preparation
 from datasets import load_dataset
 
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", 'en', split="train")
+dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
 # To select the first 2000 examples
 train_dataset = dataset.select(range(2000))
 
 # To select the next 200 examples for evaluation
 eval_dataset = dataset.select(range(2000, 2200))
 
+
 # Convert dataset to OAI messages
 def format_data(sample):
-    return {"messages": [
-                {
-                    "role": "system",
-                    "content": [{"type": "text", "text": system_message}],
-                },
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": sample["question"],
-                        },{
-                            "type": "image",
-                            "image": sample["image"],
-                        }
-                    ],
-                },
-                {
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": sample["answer"]}],
-                },
-            ],
-        }
+    return {
+        "messages": [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": system_message}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": sample["question"],
+                    },
+                    {
+                        "type": "image",
+                        "image": sample["image"],
+                    },
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": sample["answer"]}],
+            },
+        ],
+    }
+
 
 system_message = "You are an expert french ocr system."
 # Convert dataset to OAI messages
@@ -78,42 +82,44 @@ def format_data(sample):
 
 model, tokenizer = FastVisionModel.from_pretrained(
     model_name = "unsloth/Qwen2-VL-7B-Instruct",
-    max_seq_length = 2048, # Choose any for long context!
+    max_seq_length = 2048,  # Choose any for long context!
     load_in_4bit = True,  # 4 bit quantization to reduce memory
-    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
-    full_finetuning = False, # [NEW!] We have full finetuning now!
+    load_in_8bit = False,  # [NEW!] A bit more accurate, uses 2x memory
+    full_finetuning = False,  # [NEW!] We have full finetuning now!
 )
 
 # benchmark base model performance
 model_name = "Unsloth Base model"
 FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 ## Lora Finetuning
 model = FastVisionModel.get_peft_model(
     model,
-    finetune_vision_layers     = True, # Turn off for just text!
-    finetune_language_layers   = True,  # Should leave on!
+    finetune_vision_layers = True,  # Turn off for just text!
+    finetune_language_layers = True,  # Should leave on!
     finetune_attention_modules = True,  # Attention good for GRPO
-    finetune_mlp_modules       = True,  # SHould leave on always!
-
-    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-    #target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
-                      #"gate_proj", "up_proj", "down_proj",],
+    finetune_mlp_modules = True,  # SHould leave on always!
+    r = 16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+    # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+    # "gate_proj", "up_proj", "down_proj",],
     lora_alpha = 32,
-    lora_dropout = 0, # Supports any, but = 0 is optimized
-    bias = "none",    # Supports any, but = "none" is optimized
+    lora_dropout = 0,  # Supports any, but = 0 is optimized
+    bias = "none",  # Supports any, but = "none" is optimized
     # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
-    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+    use_gradient_checkpointing = "unsloth",  # True or "unsloth" for very long context
     random_state = 3407,
     use_rslora = False,  # We support rank stabilized LoRA
-    loftq_config = None, # And LoftQ
+    loftq_config = None,  # And LoftQ
 )
 
 from unsloth import is_bf16_supported
 from unsloth.trainer import UnslothVisionDataCollator
-FastVisionModel.for_training(model) # Enable for training!
+
+FastVisionModel.for_training(model)  # Enable for training!
 model.config.use_cache = False
 
 
@@ -123,28 +129,29 @@ def format_data(sample):
     data_collator = UnslothVisionDataCollator(model, tokenizer),
     train_dataset = train_dataset,
     args = SFTConfig(
-        #per_device_train_batch_size = 4,
-        #gradient_accumulation_steps = 8,
+        # per_device_train_batch_size = 4,
+        # gradient_accumulation_steps = 8,
         per_device_train_batch_size = 2,
         gradient_accumulation_steps = 4,
-        gradient_checkpointing=True,
-        gradient_checkpointing_kwargs = {"use_reentrant": False}, # use reentrant checkpointing
-        max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
-        warmup_ratio=0.03,
-        #num_train_epochs = 2, # Set this instead of max_steps for full training runs
-        max_steps=60,
+        gradient_checkpointing = True,
+        gradient_checkpointing_kwargs = {
+            "use_reentrant": False
+        },  # use reentrant checkpointing
+        max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper
+        warmup_ratio = 0.03,
+        # num_train_epochs = 2, # Set this instead of max_steps for full training runs
+        max_steps = 60,
         learning_rate = 2e-4,
         fp16 = not is_bf16_supported(),
         bf16 = is_bf16_supported(),
         logging_steps = 5,
-        save_strategy="epoch",
+        save_strategy = "epoch",
         optim = "adamw_torch_fused",
         weight_decay = 0.01,
         lr_scheduler_type = "linear",
         seed = 3407,
         output_dir = "unsloth-qwen2-7vl-french-ocr-checkpoints",
-        report_to = "none",     # For Weights and Biases
-
+        report_to = "none",  # For Weights and Biases
         # You MUST put the below items for vision finetuning:
         remove_unused_columns = False,
         dataset_text_field = "",
@@ -165,11 +172,14 @@ def format_data(sample):
 # benchmark lora model performance
 model_name = "Unsloth lora adapter model"
 FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 ## Merge Model
 
+
 def find_lora_base_model(model_to_inspect):
     current = model_to_inspect
     if hasattr(current, "base_model"):
@@ -177,45 +187,68 @@ def find_lora_base_model(model_to_inspect):
     if hasattr(current, "model"):
         current = current.model
     return current
-pass
+
 
 base = find_lora_base_model(model)
 
 print((base.__class__.__name__))
 
 # merge default 16 bits
-model.save_pretrained_merged(save_directory="qwen2-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer)
+model.save_pretrained_merged(
+    save_directory = "qwen2-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
+)
 
 
 ## Benchmark merged model performance
 
 ### 16 bits merged model
 
-model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-16bit",load_in_4bit=False, load_in_8bit=False)
+model, tokenizer = FastVisionModel.from_pretrained(
+    "./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
+)
 
 # benchmark 4bit loaded, 16bits merged model performance
 model_name = "Unsloth 16bits-merged model load-16bits"
 model.config.use_cache = True
 
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_16bits_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model,
+    tokenizer,
+    eval_dataset,
+    output_dir = "unsloth_16bits_merged_model_load_16bits_results",
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # load 16bits-merged model in 4 bits
-model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-16bit",load_in_4bit=True, load_in_8bit=False)
+model, tokenizer = FastVisionModel.from_pretrained(
+    "./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
+)
 
 # benchmark 4bit loaded, 16bits merged model performance
 model_name = "Unsloth 16bits-merged model load-4bits"
 model.config.use_cache = True
 
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_4bits_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model,
+    tokenizer,
+    eval_dataset,
+    output_dir = "unsloth_16bits_merged_model_load_4bits_results",
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # load model in 8 bits
-model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-16bit",load_in_4bit=False, load_in_8bit=True)
+model, tokenizer = FastVisionModel.from_pretrained(
+    "./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
+)
 
 # benchmark 4bit loaded, 16bits merged model performance
 model_name = "Unsloth 16bits-merged model load-8bits"
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_8bits_results")
+avg_wer, avg_cer = ocr_evaluator.evaluate_model(
+    model,
+    tokenizer,
+    eval_dataset,
+    output_dir = "unsloth_16bits_merged_model_load_8bits_results",
+)
 ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # """### 4 bits merged model"""
@@ -239,11 +272,10 @@ def find_lora_base_model(model_to_inspect):
 # ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
 
 # Model comparison report
-#print model comparison
+# print model comparison
 ocr_evaluator.print_model_comparison()
 
 
-
 # Final cleanup
 print("\nš§¹ Cleaning up temporary files...")
 safe_remove_directory("./unsloth-qwen2-7vl-french-ocr-adapter")
diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py
index f59f4f0da..fb9a734c6 100644
--- a/tests/test_model_registry.py
+++ b/tests/test_model_registry.py
@@ -64,16 +64,16 @@ def _test_model_uploaded(model_ids: list[str]):
 
 
 # Test that model registration methods register respective models
-@pytest.mark.parametrize("model_test_param", TestParams, ids=lambda param: param.name)
+@pytest.mark.parametrize("model_test_param", TestParams, ids = lambda param: param.name)
 def test_model_registration(model_test_param: ModelTestParam):
     MODEL_REGISTRY.clear()
     registration_method = model_test_param.register_models
     registration_method()
     registered_models = MODEL_REGISTRY.keys()
     missing_models = _test_model_uploaded(registered_models)
-    assert not missing_models, (
-        f"{model_test_param.name} missing following models: {missing_models}"
-    )
+    assert (
+        not missing_models
+    ), f"{model_test_param.name} missing following models: {missing_models}"
 
 
 def test_all_model_registration():
@@ -82,10 +82,11 @@ def test_all_model_registration():
     missing_models = _test_model_uploaded(registered_models)
     assert not missing_models, f"Missing following models: {missing_models}"
 
+
 def test_quant_type():
     # Test that the quant_type is correctly set for model paths
     # NOTE: for models registered under org="unsloth" with QuantType.NONE aliases QuantType.UNSLOTH
-    dynamic_quant_models = search_models(quant_types=[QuantType.UNSLOTH])
+    dynamic_quant_models = search_models(quant_types = [QuantType.UNSLOTH])
     assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models)
     quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH]
-    assert all(quant_tag in m.model_path for m in dynamic_quant_models)
\ No newline at end of file
+    assert all(quant_tag in m.model_path for m in dynamic_quant_models)
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index cd5d0d96c..3ad7a4f0f 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -25,7 +25,7 @@ def timer(name):
 
 
 @contextmanager
-def header_footer_context(title: str, char="-"):
+def header_footer_context(title: str, char = "-"):
     print()
     print(f"{char}" * 50 + f" {title} " + f"{char}" * 50)
     yield
diff --git a/tests/utils/aime_eval.py b/tests/utils/aime_eval.py
index 54b0d8e51..d70dc3705 100644
--- a/tests/utils/aime_eval.py
+++ b/tests/utils/aime_eval.py
@@ -21,10 +21,10 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
     datasets = {
         "test2024": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2024.jsonl",
         "test2025-I": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-I.jsonl",
-        "test2025-II": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl"
+        "test2025-II": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl",
     }
 
-    os.makedirs(data_dir, exist_ok=True)
+    os.makedirs(data_dir, exist_ok = True)
     combined_filepath = os.path.join(data_dir, "aime.jsonl")
 
     # Check if combined file already exists
@@ -45,18 +45,20 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
             response.raise_for_status()
 
             # Parse each line and add source information
-            for line_num, line in enumerate(response.text.strip().split('\n')):
+            for line_num, line in enumerate(response.text.strip().split("\n")):
                 if line.strip():
                     try:
                         data = json.loads(line)
                         # Add source dataset information and global ID
-                        data['source_dataset'] = dataset_name
-                        data['original_id'] = data.get('id', line_num)
-                        data['global_id'] = global_id
+                        data["source_dataset"] = dataset_name
+                        data["original_id"] = data.get("id", line_num)
+                        data["global_id"] = global_id
                         global_id += 1
                         all_problems.append(data)
                     except json.JSONDecodeError as e:
-                        print(f"    Warning: Error parsing line {line_num + 1} in {dataset_name}: {e}")
+                        print(
+                            f"    Warning: Error parsing line {line_num + 1} in {dataset_name}: {e}"
+                        )
                         continue
 
         except requests.RequestException as e:
@@ -65,16 +67,16 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
 
     # Write combined dataset
     if all_problems:
-        with open(combined_filepath, 'w', encoding='utf-8') as f:
+        with open(combined_filepath, "w", encoding = "utf-8") as f:
             for problem in all_problems:
-                f.write(json.dumps(problem, ensure_ascii=False) + '\n')
+                f.write(json.dumps(problem, ensure_ascii = False) + "\n")
 
         print(f"ā
 Combined {len(all_problems)} problems from {len(datasets)} datasets")
         print(f"   Saved to: {combined_filepath}")
 
         # Print summary by dataset
         for dataset_name in datasets.keys():
-            count = sum(1 for p in all_problems if p['source_dataset'] == dataset_name)
+            count = sum(1 for p in all_problems if p["source_dataset"] == dataset_name)
             print(f"   {dataset_name}: {count} problems")
 
     else:
@@ -90,7 +92,7 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
     filepath = download_and_combine_aime_datasets(data_dir)
 
     examples = []
-    with open(filepath, 'r', encoding='utf-8') as f:
+    with open(filepath, "r", encoding = "utf-8") as f:
         for line_num, line in enumerate(f):
             line = line.strip()
             if line:
@@ -100,7 +102,9 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
                     # Format as expected by our evaluation
                     formatted_example = {
                         "global_id": data.get("global_id", line_num),
-                        "original_id": data.get("original_id", data.get("id", line_num)),
+                        "original_id": data.get(
+                            "original_id", data.get("id", line_num)
+                        ),
                         "source_dataset": data.get("source_dataset", "unknown"),
                         "problem": data["problem"],
                         "answer": str(data["answer"]),  # Ensure answer is string
@@ -108,9 +112,15 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
                         "url": data.get("url", ""),
                         # Format as chat messages for the model
                         "prompt": [
-                            {"role": "system", "content": "You are a mathematical problem solver. Solve the given problem step by step and provide your final answer clearly."},
-                            {"role": "user", "content": f"Problem: {data['problem']}\n\nSolve this step by step and provide your final numerical answer."}
-                        ]
+                            {
+                                "role": "system",
+                                "content": "You are a mathematical problem solver. Solve the given problem step by step and provide your final answer clearly.",
+                            },
+                            {
+                                "role": "user",
+                                "content": f"Problem: {data['problem']}\n\nSolve this step by step and provide your final numerical answer.",
+                            },
+                        ],
                     }
                     examples.append(formatted_example)
 
@@ -123,7 +133,7 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
     # Print breakdown by source
     source_counts = {}
     for example in examples:
-        source = example['source_dataset']
+        source = example["source_dataset"]
         source_counts[source] = source_counts.get(source, 0) + 1
 
     for source, count in source_counts.items():
@@ -161,7 +171,7 @@ def extract_aime_answer(response: str) -> str:
                 continue
 
     # If no clear pattern found, try to extract any 1-3 digit number
-    numbers = re.findall(r'\b(\d{1,3})\b', response)
+    numbers = re.findall(r"\b(\d{1,3})\b", response)
     if numbers:
         for num_str in reversed(numbers):  # Check from end
             try:
@@ -178,18 +188,27 @@ def get_num_tokens(text, tokenizer_instance):
     """Count tokens in text"""
     if not text:
         return 0
-    encoding = tokenizer_instance(text, return_tensors="pt")
+    encoding = tokenizer_instance(text, return_tensors = "pt")
     return len(encoding["input_ids"][0])
 
 
-def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
-                       temperature=0.3, n_sampling=8, max_tokens=32768, top_p=0.95, seed=0):
+def evaluate_model_aime(
+    model,
+    tokenizer,
+    model_type = "base",
+    lora_request = None,
+    temperature = 0.3,
+    n_sampling = 8,
+    max_tokens = 32768,
+    top_p = 0.95,
+    seed = 0,
+):
     """Evaluate model on combined AIME dataset with official configuration"""
 
-    print(f"\n{'='*70}")
+    print(f"\n{'=' * 70}")
     print(f"š§® AIME EVALUATION - {model_type.upper()} MODEL")
     print(f"Combined Dataset: test2024 + test2025-I + test2025-II")
-    print(f"{'='*70}")
+    print(f"{'=' * 70}")
 
     # Load combined AIME dataset
     try:
@@ -211,18 +230,18 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
     # Track performance by source dataset
     source_stats = {}
     for example in eval_dataset:
-        source = example['source_dataset']
+        source = example["source_dataset"]
         if source not in source_stats:
-            source_stats[source] = {'total': 0, 'correct': 0}
-        source_stats[source]['total'] += 1
+            source_stats[source] = {"total": 0, "correct": 0}
+        source_stats[source]["total"] += 1
 
     # Setup sampling parameters (AIME configuration)
     sampling_params = SamplingParams(
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=max_tokens,
-        n=n_sampling,  # Multiple samples per question
-        seed=seed,
+        temperature = temperature,
+        top_p = top_p,
+        max_tokens = max_tokens,
+        n = n_sampling,  # Multiple samples per question
+        seed = seed,
     )
 
     print(f"\nš§ Configuration:")
@@ -234,7 +253,14 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
 
     # Temporarily suppress verbose logging
     original_levels = {}
-    loggers_to_suppress = ['vllm', 'vllm.engine', 'vllm.worker', 'vllm.model_executor', 'vllm.executor', 'ray']
+    loggers_to_suppress = [
+        "vllm",
+        "vllm.engine",
+        "vllm.worker",
+        "vllm.model_executor",
+        "vllm.executor",
+        "ray",
+    ]
 
     for logger_name in loggers_to_suppress:
         logger = logging.getLogger(logger_name)
@@ -245,14 +271,14 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
         print(f"\nš Evaluating {len(eval_dataset)} problems...")
 
         # Main evaluation loop
-        with tqdm(total=len(eval_dataset), desc="Processing AIME problems", unit="problem") as pbar:
+        with tqdm(
+            total = len(eval_dataset), desc = "Processing AIME problems", unit = "problem"
+        ) as pbar:
             for task_id, item in enumerate(eval_dataset):
                 try:
                     # Prepare prompt
                     prompt_text = tokenizer.apply_chat_template(
-                        item["prompt"],
-                        add_generation_prompt=True,
-                        tokenize=False
+                        item["prompt"], add_generation_prompt = True, tokenize = False
                     )
 
                     input_tokens.append(get_num_tokens(prompt_text, tokenizer))
@@ -260,27 +286,33 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
                     # Generate multiple responses
                     outputs = model.fast_generate(
                         [prompt_text],
-                        sampling_params=sampling_params,
-                        lora_request=lora_request,
-                        use_tqdm=False,
+                        sampling_params = sampling_params,
+                        lora_request = lora_request,
+                        use_tqdm = False,
                     )[0].outputs
 
                     # Process all generated responses
                     responses = [output.text for output in outputs]
-                    extracted_answers = [extract_aime_answer(response) for response in responses]
+                    extracted_answers = [
+                        extract_aime_answer(response) for response in responses
+                    ]
 
                     # Calculate total output tokens
-                    total_output_tokens = sum(get_num_tokens(response, tokenizer) for response in responses)
+                    total_output_tokens = sum(
+                        get_num_tokens(response, tokenizer) for response in responses
+                    )
                     output_tokens.append(total_output_tokens)
 
                     # Check if any answer is correct
                     ground_truth = item["answer"]
-                    correct_responses = [ans == ground_truth for ans in extracted_answers]
+                    correct_responses = [
+                        ans == ground_truth for ans in extracted_answers
+                    ]
                     is_correct = any(correct_responses)
 
                     if is_correct:
                         correct_answers += 1
-                        source_stats[item['source_dataset']]['correct'] += 1
+                        source_stats[item["source_dataset"]]["correct"] += 1
 
                     # Store detailed record
                     records[task_id] = {
@@ -298,16 +330,18 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
                         "n_correct": sum(correct_responses),
                         "n_total": len(responses),
                         "solution": item.get("solution", ""),
-                        "url": item.get("url", "")
+                        "url": item.get("url", ""),
                     }
 
                     # Update progress
                     current_accuracy = correct_answers / (task_id + 1) * 100
-                    pbar.set_postfix({
-                        'accuracy': f'{current_accuracy:.1f}%',
-                        'correct': correct_answers,
-                        'total': task_id + 1
-                    })
+                    pbar.set_postfix(
+                        {
+                            "accuracy": f"{current_accuracy:.1f}%",
+                            "correct": correct_answers,
+                            "total": task_id + 1,
+                        }
+                    )
                     pbar.update(1)
 
                 except Exception as e:
@@ -319,7 +353,7 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
                         "problem": item["problem"],
                         "ground_truth": item["answer"],
                         "error": str(e),
-                        "is_correct": False
+                        "is_correct": False,
                     }
                     pbar.update(1)
                     continue
@@ -349,7 +383,9 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
     # Calculate per-source accuracies
     source_accuracies = {}
     for source, stats in source_stats.items():
-        source_accuracies[source] = (stats['correct'] / stats['total'] * 100) if stats['total'] > 0 else 0
+        source_accuracies[source] = (
+            (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
+        )
 
     results = {
         "model_type": model_type,
@@ -365,31 +401,39 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
         "max_tokens": max_tokens,
         "top_p": top_p,
         "seed": seed,
-        "avg_input_tokens": sum(input_tokens) / len(input_tokens) if input_tokens else 0,
-        "avg_output_tokens": sum(output_tokens) / len(output_tokens) if output_tokens else 0,
+        "avg_input_tokens": sum(input_tokens) / len(input_tokens)
+        if input_tokens
+        else 0,
+        "avg_output_tokens": sum(output_tokens) / len(output_tokens)
+        if output_tokens
+        else 0,
         "max_input_tokens": max(input_tokens) if input_tokens else 0,
         "max_output_tokens": max(output_tokens) if output_tokens else 0,
     }
 
     # Save results
     filename = f"aime_eval_combined_{model_type}_t{temperature}_n{n_sampling}.json"
-    with open(filename, "w", encoding="utf-8") as f:
-        json.dump({"results": results, "records": records}, f, indent=4)
+    with open(filename, "w", encoding = "utf-8") as f:
+        json.dump({"results": results, "records": records}, f, indent = 4)
 
     # Print comprehensive summary
-    print(f"\n{'='*70}")
+    print(f"\n{'=' * 70}")
     print(f"š AIME EVALUATION RESULTS - {model_type.upper()}")
-    print(f"{'='*70}")
+    print(f"{'=' * 70}")
 
     print(f"\nšÆ Overall Performance:")
     print(f"   Total problems:       {total_problems:>6}")
-    print(f"   Correct answers:      {correct_answers:>6}/{total_problems} ({accuracy:>5.1f}%)")
+    print(
+        f"   Correct answers:      {correct_answers:>6}/{total_problems} ({accuracy:>5.1f}%)"
+    )
     print(f"   Pass@{n_sampling}:              {pass_at_k:>10.1f}%")
 
     print(f"\nš Performance by Dataset:")
     for source, stats in source_stats.items():
         source_acc = source_accuracies[source]
-        print(f"   {source:>12}: {stats['correct']:>3}/{stats['total']:>3} ({source_acc:>5.1f}%)")
+        print(
+            f"   {source:>12}: {stats['correct']:>3}/{stats['total']:>3} ({source_acc:>5.1f}%)"
+        )
 
     print(f"\nš§ Configuration:")
     print(f"   Temperature:          {temperature}")
@@ -420,7 +464,7 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
 
     print(f"\nšļø  AIME Performance:     {tier} ({accuracy:.1f}%)")
     print(f"\nš¾ Detailed results saved to: {filename}")
-    print(f"\n{'='*70}")
+    print(f"\n{'=' * 70}")
 
     return results
 
@@ -428,68 +472,74 @@ def evaluate_model_aime(model, tokenizer, model_type="base", lora_request=None,
 # Comparison functions for multiple model results
 def compare_aime_results(all_results):
     """Generate comprehensive comparison for AIME evaluation results"""
-    print(f"\n{'='*80}")
+    print(f"\n{'=' * 80}")
     print("COMPREHENSIVE AIME MODEL COMPARISON")
-    print(f"{'='*80}")
+    print(f"{'=' * 80}")
 
     # Main comparison table
-    print(f"{'Model':<15} {'Accuracy %':<12} {'Pass@K %':<10} {'Correct':<8} {'Total':<8}")
+    print(
+        f"{'Model':<15} {'Accuracy %':<12} {'Pass@K %':<10} {'Correct':<8} {'Total':<8}"
+    )
     print("-" * 80)
 
     for result in all_results:
-        print(f"{result['model_type']:<15} "
-              f"{result['accuracy']:<12.1f} "
-              f"{result['pass_at_k']:<10.1f} "
-              f"{result['correct_answers']:<8} "
-              f"{result['total_problems']:<8}")
+        print(
+            f"{result['model_type']:<15} "
+            f"{result['accuracy']:<12.1f} "
+            f"{result['pass_at_k']:<10.1f} "
+            f"{result['correct_answers']:<8} "
+            f"{result['total_problems']:<8}"
+        )
 
     # Performance improvement analysis
     if len(all_results) > 1:
-        print(f"\n{'='*50}")
+        print(f"\n{'=' * 50}")
         print("IMPROVEMENT ANALYSIS")
-        print(f"{'='*50}")
+        print(f"{'=' * 50}")
 
         base_result = all_results[0]  # Assume first is base model
 
         for i, result in enumerate(all_results[1:], 1):
             print(f"\n{result['model_type']} vs {base_result['model_type']}:")
 
-            accuracy_improvement = result['accuracy'] - base_result['accuracy']
-            pass_k_improvement = result['pass_at_k'] - base_result['pass_at_k']
+            accuracy_improvement = result["accuracy"] - base_result["accuracy"]
+            pass_k_improvement = result["pass_at_k"] - base_result["pass_at_k"]
 
             print(f"  Accuracy improvement:  {accuracy_improvement:+.1f}%")
             print(f"  Pass@K improvement:    {pass_k_improvement:+.1f}%")
 
     # Dataset breakdown
-    print(f"\n{'='*50}")
+    print(f"\n{'=' * 50}")
     print("PERFORMANCE BY DATASET")
-    print(f"{'='*50}")
+    print(f"{'=' * 50}")
 
     # Get all unique datasets from the first result
-    if all_results and 'source_accuracies' in all_results[0]:
-        datasets = list(all_results[0]['source_accuracies'].keys())
+    if all_results and "source_accuracies" in all_results[0]:
+        datasets = list(all_results[0]["source_accuracies"].keys())
 
-        print(f"{'Model':<15}", end="")
+        print(f"{'Model':<15}", end = "")
         for dataset in datasets:
-            print(f"{dataset:<15}", end="")
+            print(f"{dataset:<15}", end = "")
         print()
         print("-" * (15 + 15 * len(datasets)))
 
         for result in all_results:
-            print(f"{result['model_type']:<15}", end="")
+            print(f"{result['model_type']:<15}", end = "")
             for dataset in datasets:
-                accuracy = result['source_accuracies'].get(dataset, 0)
-                print(f"{accuracy:<15.1f}", end="")
+                accuracy = result["source_accuracies"].get(dataset, 0)
+                print(f"{accuracy:<15.1f}", end = "")
             print()
 
     # Save comparison
     comparison_data = {
         "summary": all_results,
-        "best_model": max(all_results, key=lambda x: x['accuracy']),
+        "best_model": max(all_results, key = lambda x: x["accuracy"]),
     }
 
     with open("aime_model_comparison.json", "w") as f:
-        json.dump(comparison_data, f, indent=4)
+        json.dump(comparison_data, f, indent = 4)
 
-    print(f"\nBest performing model: {comparison_data['best_model']['model_type']} "
-          f"({comparison_data['best_model']['accuracy']:.1f}% accuracy)")
+    print(
+        f"\nBest performing model: {comparison_data['best_model']['model_type']} "
+        f"({comparison_data['best_model']['accuracy']:.1f}% accuracy)"
+    )
diff --git a/tests/utils/cleanup_utils.py b/tests/utils/cleanup_utils.py
index e9163b131..4995bc4d7 100644
--- a/tests/utils/cleanup_utils.py
+++ b/tests/utils/cleanup_utils.py
@@ -6,7 +6,8 @@
 import sys
 import warnings
 
-def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
+
+def clear_memory(variables_to_clear = None, verbose = False, clear_all_caches = True):
     """
     Comprehensive memory clearing for persistent memory leaks.
 
@@ -24,9 +25,18 @@ def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
     root_level = logging.getLogger().level
 
     if variables_to_clear is None:
-        variables_to_clear = ["inputs", "model", "base_model", "processor", "tokenizer",
-                             "base_processor", "base_tokenizer", "trainer",
-                             "peft_model", "bnb_config"]
+        variables_to_clear = [
+            "inputs",
+            "model",
+            "base_model",
+            "processor",
+            "tokenizer",
+            "base_processor",
+            "base_tokenizer",
+            "trainer",
+            "peft_model",
+            "bnb_config",
+        ]
 
     # 1. Clear LRU caches FIRST (very important for memory leaks)
     if clear_all_caches:
@@ -47,7 +57,7 @@ def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
     for i in range(3):
         collected = gc.collect()
         if verbose and collected > 0:
-            print(f"GC pass {i+1}: collected {collected} objects")
+            print(f"GC pass {i + 1}: collected {collected} objects")
 
     # 4. CUDA cleanup
     if torch.cuda.is_available():
@@ -65,7 +75,9 @@ def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
             torch.cuda.reset_accumulated_memory_stats()
 
             # Clear JIT cache
-            if hasattr(torch.jit, '_state') and hasattr(torch.jit._state, '_clear_class_state'):
+            if hasattr(torch.jit, "_state") and hasattr(
+                torch.jit._state, "_clear_class_state"
+            ):
                 torch.jit._state._clear_class_state()
 
             # Force another CUDA cache clear
@@ -77,7 +89,9 @@ def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
         if verbose:
             mem_after = torch.cuda.memory_allocated() / 1024**3
             mem_reserved = torch.cuda.memory_reserved() / 1024**3
-            print(f"GPU memory - Before: {mem_before:.2f} GB, After: {mem_after:.2f} GB")
+            print(
+                f"GPU memory - Before: {mem_before:.2f} GB, After: {mem_after:.2f} GB"
+            )
             print(f"GPU reserved memory: {mem_reserved:.2f} GB")
             if mem_before > 0:
                 print(f"Memory freed: {mem_before - mem_after:.2f} GB")
@@ -89,17 +103,18 @@ def clear_memory(variables_to_clear=None, verbose=False, clear_all_caches=True):
             logger = logging.getLogger(name)
             logger.setLevel(level)
 
-def clear_all_lru_caches(verbose=True):
+
+def clear_all_lru_caches(verbose = True):
     """Clear all LRU caches in loaded modules."""
     cleared_caches = []
 
     # Modules to skip to avoid warnings
     skip_modules = {
-        'torch.distributed',
-        'torchaudio',
-        'torch._C',
-        'torch.distributed.reduce_op',
-        'torchaudio.backend',
+        "torch.distributed",
+        "torchaudio",
+        "torch._C",
+        "torch.distributed.reduce_op",
+        "torchaudio.backend",
     }
 
     # Create a static list of modules to avoid RuntimeError
@@ -125,7 +140,7 @@ def clear_all_lru_caches(verbose=True):
                         warnings.simplefilter("ignore", DeprecationWarning)
 
                     attr = getattr(module, attr_name)
-                    if hasattr(attr, 'cache_clear'):
+                    if hasattr(attr, "cache_clear"):
                         attr.cache_clear()
                         cleared_caches.append(f"{module_name}.{attr_name}")
                 except Exception:
@@ -135,14 +150,14 @@ def clear_all_lru_caches(verbose=True):
 
     # Method 2: Clear specific known caches
     known_caches = [
-        'transformers.utils.hub.cached_file',
-        'transformers.tokenization_utils_base.get_tokenizer',
-        'torch._dynamo.utils.counters',
+        "transformers.utils.hub.cached_file",
+        "transformers.tokenization_utils_base.get_tokenizer",
+        "torch._dynamo.utils.counters",
     ]
 
     for cache_path in known_caches:
         try:
-            parts = cache_path.split('.')
+            parts = cache_path.split(".")
             module = sys.modules.get(parts[0])
             if module:
                 obj = module
@@ -150,7 +165,7 @@ def clear_all_lru_caches(verbose=True):
                     obj = getattr(obj, part, None)
                     if obj is None:
                         break
-                if obj and hasattr(obj, 'cache_clear'):
+                if obj and hasattr(obj, "cache_clear"):
                     obj.cache_clear()
                     cleared_caches.append(cache_path)
         except Exception:
@@ -162,7 +177,7 @@ def clear_all_lru_caches(verbose=True):
 
 def clear_specific_lru_cache(func):
     """Clear cache for a specific function."""
-    if hasattr(func, 'cache_clear'):
+    if hasattr(func, "cache_clear"):
         func.cache_clear()
         return True
     return False
@@ -180,20 +195,22 @@ def monitor_cache_sizes():
             for attr_name in dir(module):
                 try:
                     attr = getattr(module, attr_name)
-                    if hasattr(attr, 'cache_info'):
+                    if hasattr(attr, "cache_info"):
                         info = attr.cache_info()
-                        cache_info.append({
-                            'function': f"{module_name}.{attr_name}",
-                            'size': info.currsize,
-                            'hits': info.hits,
-                            'misses': info.misses
-                        })
+                        cache_info.append(
+                            {
+                                "function": f"{module_name}.{attr_name}",
+                                "size": info.currsize,
+                                "hits": info.hits,
+                                "misses": info.misses,
+                            }
+                        )
                 except:
                     pass
         except:
             pass
 
-    return sorted(cache_info, key=lambda x: x['size'], reverse=True)
+    return sorted(cache_info, key = lambda x: x["size"], reverse = True)
 
 
 def safe_remove_directory(path):
diff --git a/tests/utils/data_utils.py b/tests/utils/data_utils.py
index 7682fe480..955168843 100644
--- a/tests/utils/data_utils.py
+++ b/tests/utils/data_utils.py
@@ -32,10 +32,10 @@ def create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = N
     dataset = create_instruction_dataset(messages)
 
     def _apply_chat_template(example):
-        chat = tokenizer.apply_chat_template(example["messages"], tokenize=False)
+        chat = tokenizer.apply_chat_template(example["messages"], tokenize = False)
         return {"text": chat}
 
-    dataset = dataset.map(_apply_chat_template, remove_columns="messages")
+    dataset = dataset.map(_apply_chat_template, remove_columns = "messages")
     if num_examples is not None:
         if len(dataset) < num_examples:
             num_repeats = num_examples // len(dataset) + 1
@@ -139,11 +139,11 @@ def get_peft_weights(model):
 
 def describe_peft_weights(model):
     for name, param in get_peft_weights(model).items():
-        yield name, describe_param(param, as_str=True)
+        yield name, describe_param(param, as_str = True)
 
 
 def check_responses(responses: list[str], answer: str, prompt: str = None) -> bool:
-    for i, response in enumerate(responses, start=1):
+    for i, response in enumerate(responses, start = 1):
         if answer in response:
             print(f"\u2713 response {i} contains answer")
         else:
diff --git a/tests/utils/hf_utils.py b/tests/utils/hf_utils.py
index cc5edce02..8ad6d5ad0 100644
--- a/tests/utils/hf_utils.py
+++ b/tests/utils/hf_utils.py
@@ -79,27 +79,27 @@ def generate_responses(
     skip_special_tokens: bool = True,
     dtype: torch.dtype = None,
 ):
-    inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)]
+    inputs = [tokenizer(prompt, return_tensors = "pt") for _ in range(num_generations)]
     keys = inputs[0].keys()
     batched_inputs = {
-        key: torch.cat([input[key] for input in inputs], dim=0).to(model.device)
+        key: torch.cat([input[key] for input in inputs], dim = 0).to(model.device)
         for key in keys
     }
 
     if dtype is not None:
-        inference_context = torch.autocast(device_type="cuda", dtype=dtype)
+        inference_context = torch.autocast(device_type = "cuda", dtype = dtype)
     else:
         inference_context = nullcontext()
 
     with inference_context:
         outputs = model.generate(
             **batched_inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
+            max_new_tokens = max_new_tokens,
+            do_sample = do_sample,
+            temperature = temperature,
         )
 
-    responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
+    responses = tokenizer.batch_decode(outputs, skip_special_tokens = skip_special_tokens)
     return responses
 
 
@@ -117,11 +117,11 @@ def sample_responses(
         model,
         tokenizer,
         prompt,
-        temperature=temperature,
-        num_generations=num_generations,
-        max_new_tokens=max_new_tokens,
-        skip_special_tokens=skip_special_tokens,
-        dtype=dtype,
+        temperature = temperature,
+        num_generations = num_generations,
+        max_new_tokens = max_new_tokens,
+        skip_special_tokens = skip_special_tokens,
+        dtype = dtype,
     )
     return responses
 
@@ -136,32 +136,32 @@ def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []):
 def setup_model(
     model_name,
     quantize: bool = True,
-    dtype=torch.bfloat16,
-    peft_config=None,
+    dtype = torch.bfloat16,
+    peft_config = None,
     autocast_adapter: bool = True,
 ):
     if quantize:
         bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=dtype,
+            load_in_4bit = True,
+            bnb_4bit_use_double_quant = True,
+            bnb_4bit_quant_type = "nf4",
+            bnb_4bit_compute_dtype = dtype,
         )
     else:
         bnb_config = None
 
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        device_map="cuda:0",
-        attn_implementation="sdpa",
-        quantization_config=bnb_config,
-        torch_dtype=dtype,
+        device_map = "cuda:0",
+        attn_implementation = "sdpa",
+        quantization_config = bnb_config,
+        torch_dtype = dtype,
     )
     model = prepare_model_for_kbit_training(model) if quantize else model
 
     if peft_config is not None:
         model = get_peft_model(
-            model, peft_config, autocast_adapter_dtype=autocast_adapter
+            model, peft_config, autocast_adapter_dtype = autocast_adapter
         )
 
     return model
@@ -169,19 +169,19 @@ def setup_model(
 
 def get_peft_config(
     lora_rank,
-    lora_alpha=None,
-    lora_dropout=0.0,
-    bias="none",
-    target_modules="all-linear",
+    lora_alpha = None,
+    lora_dropout = 0.0,
+    bias = "none",
+    target_modules = "all-linear",
 ):
     lora_alpha = lora_alpha or 2 * lora_rank
     peft_config = LoraConfig(
-        lora_alpha=lora_alpha,
-        lora_dropout=lora_dropout,
-        r=lora_rank,
-        bias=bias,
-        target_modules=target_modules,
-        task_type="CAUSAL_LM",
+        lora_alpha = lora_alpha,
+        lora_dropout = lora_dropout,
+        r = lora_rank,
+        bias = bias,
+        target_modules = target_modules,
+        task_type = "CAUSAL_LM",
     )
     return peft_config
 
@@ -191,18 +191,18 @@ def setup_trainer(
     tokenizer,
     dataset,
     train_args,
-    peft_config=None,
-    formatting_func=None,
-    collator=None,
+    peft_config = None,
+    formatting_func = None,
+    collator = None,
 ):
     return SFTTrainer(
-        model=model,
-        peft_config=peft_config,
-        train_dataset=dataset,
-        processing_class=tokenizer,
-        formatting_func=formatting_func,
-        data_collator=collator,
-        args=train_args,
+        model = model,
+        peft_config = peft_config,
+        train_dataset = dataset,
+        processing_class = tokenizer,
+        formatting_func = formatting_func,
+        data_collator = collator,
+        args = train_args,
     )
 
 
@@ -212,17 +212,17 @@ def setup_lora(
     dataset,
     peft_config,
     train_args,
-    formatting_func=None,
-    collator=None,
+    formatting_func = None,
+    collator = None,
 ):
     return LoraConfig(
-        model=model,
-        peft_config=peft_config,
-        train_dataset=dataset,
-        processing_class=tokenizer,
-        formatting_func=formatting_func,
-        data_collator=collator,
-        args=train_args,
+        model = model,
+        peft_config = peft_config,
+        train_dataset = dataset,
+        processing_class = tokenizer,
+        formatting_func = formatting_func,
+        data_collator = collator,
+        args = train_args,
     )
 
 
@@ -236,7 +236,7 @@ def convert_weights_back_to_dtype(model, dtype):
             param.data = param.data.to(dtype)
 
 
-def fix_llama3_tokenizer(tokenizer, padding_side="right"):
+def fix_llama3_tokenizer(tokenizer, padding_side = "right"):
     tokenizer.padding_side = padding_side
     added_vocab = tokenizer.get_added_vocab()
     pad_token = [w for w in added_vocab if "pad" in w]
@@ -276,12 +276,12 @@ def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"):
     w_dq = w_dq.to(original_dtype)
 
     new_module = torch.nn.Linear(
-        w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None
+        w_dq.shape[1], w_dq.shape[0], bias = module.base_layer.bias is not None
     )
-    new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False)
+    new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad = False)
     if module.lora_bias[adapter_name]:
         bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias
-        new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False)
+        new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad = False)
     return new_module
 
 
diff --git a/tests/utils/ocr_eval.py b/tests/utils/ocr_eval.py
index 5df476c1e..3c5cd74a2 100644
--- a/tests/utils/ocr_eval.py
+++ b/tests/utils/ocr_eval.py
@@ -35,26 +35,28 @@ def evaluate_model(
         max_new_tokens: int = 1024,
         temperature: float = 1.5,
         min_p: float = 0.1,
-        verbose: bool = True
+        verbose: bool = True,
     ) -> Tuple[Optional[float], Optional[float]]:
         """
         Evaluate a model on an OCR dataset.
         """
         # Create output directory if it doesn't exist
-        os.makedirs(output_dir, exist_ok=True)
+        os.makedirs(output_dir, exist_ok = True)
 
         # Initialize results storage
         results = []
 
         # Process each sample in the dataset
-        for i, sample in enumerate(tqdm(dataset, desc="Evaluating OCR performance", disable=not verbose)):
+        for i, sample in enumerate(
+            tqdm(dataset, desc = "Evaluating OCR performance", disable = not verbose)
+        ):
             try:
                 # Extract components from sample
-                messages = sample['messages']
+                messages = sample["messages"]
 
                 # Get ground truth, image, and question
-                ground_truth, image, question, input_messages = self._extract_sample_components(
-                    messages, i, verbose
+                ground_truth, image, question, input_messages = (
+                    self._extract_sample_components(messages, i, verbose)
                 )
 
                 if ground_truth is None or image is None or question is None:
@@ -71,18 +73,26 @@ def evaluate_model(
 
                 # Save individual result
                 self._save_individual_result(
-                    output_dir, i, question, generated_response, ground_truth, word_error, char_error
+                    output_dir,
+                    i,
+                    question,
+                    generated_response,
+                    ground_truth,
+                    word_error,
+                    char_error,
                 )
 
                 # Store results for summary
-                results.append({
-                    'sample_id': i,
-                    'wer': word_error,
-                    'cer': char_error,
-                    'model_output': generated_response.strip(),
-                    'ground_truth': ground_truth,
-                    'question': question
-                })
+                results.append(
+                    {
+                        "sample_id": i,
+                        "wer": word_error,
+                        "cer": char_error,
+                        "model_output": generated_response.strip(),
+                        "ground_truth": ground_truth,
+                        "question": question,
+                    }
+                )
 
             except Exception as e:
                 if verbose:
@@ -93,51 +103,56 @@ def evaluate_model(
         return self._generate_summary_report(results, output_dir, verbose)
 
     def _extract_sample_components(
-        self,
-        messages: List[Dict],
-        sample_idx: int,
-        verbose: bool
+        self, messages: List[Dict], sample_idx: int, verbose: bool
     ) -> Tuple[Optional[str], Optional[Any], Optional[str], List[Dict]]:
         """Extract ground truth, image, question, and input messages from sample."""
 
         # Extract system message (if present)
-        system_message = next((msg for msg in messages if msg['role'] == 'system'), None)
+        system_message = next(
+            (msg for msg in messages if msg["role"] == "system"), None
+        )
 
         # Extract user message with the image and question
-        user_message = next((msg for msg in messages if msg['role'] == 'user'), None)
+        user_message = next((msg for msg in messages if msg["role"] == "user"), None)
         if not user_message:
             if verbose:
                 print(f"Skipping sample {sample_idx}: No user message found")
             return None, None, None, []
 
         # Extract assistant message with ground truth
-        assistant_message = next((msg for msg in messages if msg['role'] == 'assistant'), None)
+        assistant_message = next(
+            (msg for msg in messages if msg["role"] == "assistant"), None
+        )
         if not assistant_message:
             if verbose:
-                print(f"Skipping sample {sample_idx}: No assistant message (ground truth) found")
+                print(
+                    f"Skipping sample {sample_idx}: No assistant message (ground truth) found"
+                )
             return None, None, None, []
 
         # Extract ground truth text
         ground_truth = None
-        for content_item in assistant_message['content']:
-            if content_item['type'] == 'text':
-                ground_truth = content_item['text']
+        for content_item in assistant_message["content"]:
+            if content_item["type"] == "text":
+                ground_truth = content_item["text"]
                 break
 
         if not ground_truth:
             if verbose:
-                print(f"Skipping sample {sample_idx}: No text found in assistant message")
+                print(
+                    f"Skipping sample {sample_idx}: No text found in assistant message"
+                )
             return None, None, None, []
 
         # Extract image and question from user message
         image = None
         question = None
 
-        for content_item in user_message['content']:
-            if content_item['type'] == 'image':
-                image = content_item['image']
-            elif content_item['type'] == 'text':
-                question = content_item['text']
+        for content_item in user_message["content"]:
+            if content_item["type"] == "image":
+                image = content_item["image"]
+            elif content_item["type"] == "text":
+                question = content_item["text"]
 
         if not image:
             if verbose:
@@ -146,7 +161,9 @@ def _extract_sample_components(
 
         if not question:
             if verbose:
-                print(f"Skipping sample {sample_idx}: No question found in user message")
+                print(
+                    f"Skipping sample {sample_idx}: No question found in user message"
+                )
             return None, None, None, []
 
         # Construct messages for the model input (excluding assistant message)
@@ -164,13 +181,13 @@ def _generate_response(
         input_messages: List[Dict],
         max_new_tokens: int,
         temperature: float,
-        min_p: float
+        min_p: float,
     ) -> str:
         """Generate response from the model."""
 
         # Preparation for inference using Qwen's specific processing
         text = processor.apply_chat_template(
-            input_messages, tokenize=False, add_generation_prompt=True
+            input_messages, tokenize = False, add_generation_prompt = True
         )
 
         # Process vision info (images/videos) from messages
@@ -178,11 +195,11 @@ def _generate_response(
 
         # Create model inputs
         inputs = processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt"
+            text = [text],
+            images = image_inputs,
+            videos = video_inputs,
+            padding = True,
+            return_tensors = "pt",
         )
         inputs = inputs.to(model.device)
 
@@ -190,22 +207,23 @@ def _generate_response(
         with torch.no_grad():
             generated_ids = model.generate(
                 **inputs,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                min_p=min_p,
-                use_cache=True
+                max_new_tokens = max_new_tokens,
+                temperature = temperature,
+                min_p = min_p,
+                use_cache = True,
             )
 
         # Extract only the generated part (not the input)
         generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            out_ids[len(in_ids) :]
+            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
 
         # Decode the generated text
         generated_response = processor.batch_decode(
             generated_ids_trimmed,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False
+            skip_special_tokens = True,
+            clean_up_tokenization_spaces = False,
         )[0]
 
         return generated_response
@@ -218,11 +236,11 @@ def _save_individual_result(
         generated_response: str,
         ground_truth: str,
         word_error: float,
-        char_error: float
+        char_error: float,
     ):
         """Save individual sample result to file."""
         output_file = os.path.join(output_dir, f"sample_{sample_idx}.txt")
-        with open(output_file, 'w', encoding='utf-8') as f:
+        with open(output_file, "w", encoding = "utf-8") as f:
             f.write(f"Sample {sample_idx}\n")
             f.write(f"Question: {question}\n\n")
             f.write(f"Model output:\n{generated_response.strip()}\n\n")
@@ -230,10 +248,7 @@ def _save_individual_result(
             f.write(f"WER: {word_error:.4f}, CER: {char_error:.4f}")
 
     def _generate_summary_report(
-        self,
-        results: List[Dict],
-        output_dir: str,
-        verbose: bool
+        self, results: List[Dict], output_dir: str, verbose: bool
     ) -> Tuple[Optional[float], Optional[float]]:
         """Generate and save summary report."""
         if not results:
@@ -244,16 +259,16 @@ def _generate_summary_report(
         df = pd.DataFrame(results)
 
         # Calculate overall averages
-        avg_wer = df['wer'].mean()
-        avg_cer = df['cer'].mean()
+        avg_wer = df["wer"].mean()
+        avg_cer = df["cer"].mean()
 
         # Save average metrics
-        with open(os.path.join(output_dir, "avg_metrics.txt"), 'w') as f:
+        with open(os.path.join(output_dir, "avg_metrics.txt"), "w") as f:
             f.write(f"Average WER: {avg_wer:.4f}\n")
             f.write(f"Average CER: {avg_cer:.4f}\n")
 
         # Save detailed results
-        df.to_csv(os.path.join(output_dir, "detailed_results.csv"), index=False)
+        df.to_csv(os.path.join(output_dir, "detailed_results.csv"), index = False)
 
         if verbose:
             print("\nResults Summary:")
@@ -265,12 +280,11 @@ def _generate_summary_report(
 
     def add_to_comparison(self, model_name: str, wer: float, cer: float):
         """Add model results to the comparison tracker."""
-        self.model_comparison_results[model_name] = {
-            "wer": wer,
-            "cer": cer
-        }
+        self.model_comparison_results[model_name] = {"wer": wer, "cer": cer}
 
-    def print_model_comparison(self, save_csv: bool = True, save_plot: bool = True) -> Optional[pd.DataFrame]:
+    def print_model_comparison(
+        self, save_csv: bool = True, save_plot: bool = True
+    ) -> Optional[pd.DataFrame]:
         """Print a comparison of all models evaluated so far."""
         if not self.model_comparison_results:
             print("No model results available for comparison")
@@ -279,23 +293,29 @@ def print_model_comparison(self, save_csv: bool = True, save_plot: bool = True)
         print("\n==== MODEL COMPARISON REPORT ====")
 
         # Create a comparison dataframe
-        comparison_df = pd.DataFrame({
-            "Model": list(self.model_comparison_results.keys()),
-            "WER": [results["wer"] for results in self.model_comparison_results.values()],
-            "CER": [results["cer"] for results in self.model_comparison_results.values()]
-        })
+        comparison_df = pd.DataFrame(
+            {
+                "Model": list(self.model_comparison_results.keys()),
+                "WER": [
+                    results["wer"] for results in self.model_comparison_results.values()
+                ],
+                "CER": [
+                    results["cer"] for results in self.model_comparison_results.values()
+                ],
+            }
+        )
 
         # Sort by WER (best performance first)
         comparison_df = comparison_df.sort_values("WER")
 
         # Display the comparison table
         print("\nComparison Table (sorted by WER):")
-        print(comparison_df.to_string(index=False))
+        print(comparison_df.to_string(index = False))
 
         # Save the comparison table
         if save_csv:
             comparison_file = "model_comparison_results.csv"
-            comparison_df.to_csv(comparison_file, index=False)
+            comparison_df.to_csv(comparison_file, index = False)
             print(f"\nComparison table saved to {comparison_file}")
 
         # Generate a bar chart visualization
@@ -306,26 +326,26 @@ def print_model_comparison(self, save_csv: bool = True, save_plot: bool = True)
 
     def _create_comparison_plot(self, comparison_df: pd.DataFrame):
         """Create and save comparison plot."""
-        plt.figure(figsize=(12, 6))
+        plt.figure(figsize = (12, 6))
 
         # Plot WER
         plt.subplot(1, 2, 1)
-        plt.bar(comparison_df["Model"], comparison_df["WER"], color='skyblue')
-        plt.title('Word Error Rate Comparison')
-        plt.ylabel('WER (lower is better)')
-        plt.ylim(bottom=0)
-        plt.xticks(rotation=45, ha='right')
+        plt.bar(comparison_df["Model"], comparison_df["WER"], color = "skyblue")
+        plt.title("Word Error Rate Comparison")
+        plt.ylabel("WER (lower is better)")
+        plt.ylim(bottom = 0)
+        plt.xticks(rotation = 45, ha = "right")
 
         # Plot CER
         plt.subplot(1, 2, 2)
-        plt.bar(comparison_df["Model"], comparison_df["CER"], color='lightgreen')
-        plt.title('Character Error Rate Comparison')
-        plt.ylabel('CER (lower is better)')
-        plt.ylim(bottom=0)
-        plt.xticks(rotation=45, ha='right')
+        plt.bar(comparison_df["Model"], comparison_df["CER"], color = "lightgreen")
+        plt.title("Character Error Rate Comparison")
+        plt.ylabel("CER (lower is better)")
+        plt.ylim(bottom = 0)
+        plt.xticks(rotation = 45, ha = "right")
 
         plt.tight_layout()
-        plt.savefig('ocr_model_comparison.png')
+        plt.savefig("ocr_model_comparison.png")
         plt.show()
 
         print(f"\nVisualization saved to ocr_model_comparison.png")
@@ -339,7 +359,9 @@ def clear_comparison_results(self):
         self.model_comparison_results.clear()
 
 
-def evaluate_ocr_model(model, processor, dataset, output_dir="ocr_evaluation_results", **kwargs):
+def evaluate_ocr_model(
+    model, processor, dataset, output_dir = "ocr_evaluation_results", **kwargs
+):
     """
     Convenience function that maintains backward compatibility with the original function.
     """
diff --git a/tests/utils/os_utils.py b/tests/utils/os_utils.py
index 281fcdbaf..448f13b8a 100644
--- a/tests/utils/os_utils.py
+++ b/tests/utils/os_utils.py
@@ -4,14 +4,15 @@
 import shutil
 import importlib
 
+
 def detect_package_manager():
     """Detect the available package manager"""
     package_managers = {
-        'apt': '/usr/bin/apt',
-        'yum': '/usr/bin/yum',
-        'dnf': '/usr/bin/dnf',
-        'pacman': '/usr/bin/pacman',
-        'zypper': '/usr/bin/zypper'
+        "apt": "/usr/bin/apt",
+        "yum": "/usr/bin/yum",
+        "dnf": "/usr/bin/dnf",
+        "pacman": "/usr/bin/pacman",
+        "zypper": "/usr/bin/zypper",
     }
 
     for pm, path in package_managers.items():
@@ -19,7 +20,8 @@ def detect_package_manager():
             return pm
     return None
 
-def check_package_installed(package_name, package_manager=None):
+
+def check_package_installed(package_name, package_manager = None):
     """Check if a package is installed using the system package manager"""
 
     if package_manager is None:
@@ -30,33 +32,38 @@ def check_package_installed(package_name, package_manager=None):
         return None
 
     try:
-        if package_manager == 'apt':
+        if package_manager == "apt":
             # Check with dpkg
-            result = subprocess.run(['dpkg', '-l', package_name],
-                                  capture_output=True, text=True)
+            result = subprocess.run(
+                ["dpkg", "-l", package_name], capture_output = True, text = True
+            )
             return result.returncode == 0
 
-        elif package_manager in ['yum', 'dnf']:
+        elif package_manager in ["yum", "dnf"]:
             # Check with rpm
-            result = subprocess.run(['rpm', '-q', package_name],
-                                  capture_output=True, text=True)
+            result = subprocess.run(
+                ["rpm", "-q", package_name], capture_output = True, text = True
+            )
             return result.returncode == 0
 
-        elif package_manager == 'pacman':
-            result = subprocess.run(['pacman', '-Q', package_name],
-                                  capture_output=True, text=True)
+        elif package_manager == "pacman":
+            result = subprocess.run(
+                ["pacman", "-Q", package_name], capture_output = True, text = True
+            )
             return result.returncode == 0
 
-        elif package_manager == 'zypper':
-            result = subprocess.run(['zypper', 'se', '-i', package_name],
-                                  capture_output=True, text=True)
+        elif package_manager == "zypper":
+            result = subprocess.run(
+                ["zypper", "se", "-i", package_name], capture_output = True, text = True
+            )
             return package_name in result.stdout
 
     except Exception as e:
         print(f"Error checking package: {e}")
         return None
 
-def require_package(package_name, executable_name=None):
+
+def require_package(package_name, executable_name = None):
     """Require a package to be installed, exit if not found"""
 
     # First check if executable is in PATH (most reliable)
@@ -78,11 +85,11 @@ def require_package(package_name, executable_name=None):
     print(f"\nPlease install {package_name} using your system package manager:")
 
     install_commands = {
-        'apt': f"sudo apt update && sudo apt install {package_name}",
-        'yum': f"sudo yum install {package_name}",
-        'dnf': f"sudo dnf install {package_name}",
-        'pacman': f"sudo pacman -S {package_name}",
-        'zypper': f"sudo zypper install {package_name}"
+        "apt": f"sudo apt update && sudo apt install {package_name}",
+        "yum": f"sudo yum install {package_name}",
+        "dnf": f"sudo dnf install {package_name}",
+        "pacman": f"sudo pacman -S {package_name}",
+        "zypper": f"sudo zypper install {package_name}",
     }
 
     if pm and pm in install_commands:
@@ -97,10 +104,12 @@ def require_package(package_name, executable_name=None):
     print(f"\nPlease install the required package and run the script again.")
     sys.exit(1)
 
+
 # Usage
-#require_package("ffmpeg", "ffmpeg")
+# require_package("ffmpeg", "ffmpeg")
+
 
-def require_python_package(package_name, import_name=None, pip_name=None):
+def require_python_package(package_name, import_name = None, pip_name = None):
     """Require a Python package to be installed, exit if not found"""
     if import_name is None:
         import_name = package_name
diff --git a/tests/utils/perplexity_eval.py b/tests/utils/perplexity_eval.py
index fa297540f..cf625fc74 100644
--- a/tests/utils/perplexity_eval.py
+++ b/tests/utils/perplexity_eval.py
@@ -3,16 +3,16 @@
 import pandas as pd
 
 model_comparison_results = {}
-#return the perplexity of the model on the dataset
-#The perplexity is computed on each example, individually, with a sliding window for examples longer than 512 tokens.
+# return the perplexity of the model on the dataset
+# The perplexity is computed on each example, individually, with a sliding window for examples longer than 512 tokens.
 
 
 def ppl_model(model, tokenizer, dataset):
     nlls = []
     max_length = 2048
     stride = 512
-    for s in tqdm(range(len(dataset['text']))):
-        encodings = tokenizer(dataset['text'][s], return_tensors="pt")
+    for s in tqdm(range(len(dataset["text"]))):
+        encodings = tokenizer(dataset["text"][s], return_tensors = "pt")
         seq_len = encodings.input_ids.size(1)
         prev_end_loc = 0
         for begin_loc in range(0, seq_len, stride):
@@ -22,10 +22,14 @@ def ppl_model(model, tokenizer, dataset):
             target_ids = input_ids.clone()
             target_ids[:, :-trg_len] = -100
             # Create attention mask based on pad token id
-            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+            pad_token_id = (
+                tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+            )
             attention_mask = (input_ids != pad_token_id).long()
             with torch.no_grad():
-                outputs = model(input_ids, labels=target_ids, attention_mask=attention_mask)
+                outputs = model(
+                    input_ids, labels = target_ids, attention_mask = attention_mask
+                )
                 neg_log_likelihood = outputs.loss
             nlls.append(neg_log_likelihood)
             prev_end_loc = end_loc
@@ -35,19 +39,17 @@ def ppl_model(model, tokenizer, dataset):
     return ppl
 
 
-#--------------------------------------------------------------------
+# --------------------------------------------------------------------
 
 
 ## ----------- Reporting helper function ----------- ##
 
+
 # Create a simple function to add results to the comparison
 def add_to_comparison(model_name, ppl):
     """Add model results to the comparison tracker"""
-    model_comparison_results[model_name] = {
-        "ppl": ppl
-    }
-    #return model_comparison_results
-
+    model_comparison_results[model_name] = {"ppl": ppl}
+    # return model_comparison_results
 
 
 # Create a function to print the comparison report whenever needed
@@ -60,16 +62,20 @@ def print_model_comparison():
     print("\n==== MODEL COMPARISON REPORT ====")
 
     # Create a comparison dataframe
-    comparison_df = pd.DataFrame({
-        "Model": list(model_comparison_results.keys()),
-        #"Perplexity": [results["ppl"] for results in model_comparison_results.values()],
-        "Perplexity": [
-            # Convert tensors to CPU and then to float if needed
-            results["ppl"].cpu().item() if torch.is_tensor(results["ppl"]) else results["ppl"]
-            for results in model_comparison_results.values()
-        ],
-    })
+    comparison_df = pd.DataFrame(
+        {
+            "Model": list(model_comparison_results.keys()),
+            # "Perplexity": [results["ppl"] for results in model_comparison_results.values()],
+            "Perplexity": [
+                # Convert tensors to CPU and then to float if needed
+                results["ppl"].cpu().item()
+                if torch.is_tensor(results["ppl"])
+                else results["ppl"]
+                for results in model_comparison_results.values()
+            ],
+        }
+    )
 
     # Display the comparison table
     print("\nComparison Table:")
-    print(comparison_df.to_string(index=False))
+    print(comparison_df.to_string(index = False))
diff --git a/tests/utils/test_qat.py b/tests/utils/test_qat.py
index f95ec7ea6..79251cf2f 100644
--- a/tests/utils/test_qat.py
+++ b/tests/utils/test_qat.py
@@ -16,6 +16,7 @@ class _CountingFakeQuantizer(torch.nn.Module):
     """
     Dummy fake quantizer that counts the number of times it has been called.
     """
+
     def __init__(self):
         super().__init__()
         self.count = 0
@@ -88,6 +89,7 @@ def _test_fake_quantizers_are_called(
     """
     Verify that the fake quantizers are actually called when the model is called.
     """
+
     def _swap_fake_quantizers(model: torch.nn.Module):
         for name, child in model.named_children():
             if isinstance(child, FakeQuantizerBase):
@@ -138,7 +140,7 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
         _test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
         _test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
         _test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
-    inputs = tokenizer("How are you?", return_tensors="pt")
+    inputs = tokenizer("How are you?", return_tensors = "pt")
     _test_fake_quantizers_are_called(model, inputs, full_finetuning)
 
 
@@ -146,9 +148,9 @@ def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
 # how to disable model caching before re-enabling this test
 @pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
 def _test_full_model_fake_quantize(qat_scheme: bool):
-    _test_model_fake_quantize(qat_scheme, full_finetuning=True)
+    _test_model_fake_quantize(qat_scheme, full_finetuning = True)
 
 
 @pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
 def test_lora_model_fake_quantize(qat_scheme: bool):
-    _test_model_fake_quantize(qat_scheme, full_finetuning=False)
+    _test_model_fake_quantize(qat_scheme, full_finetuning = False)
diff --git a/unsloth-cli.py b/unsloth-cli.py
index 7895f95dd..fb6e39266 100644
--- a/unsloth-cli.py
+++ b/unsloth-cli.py
@@ -42,29 +42,37 @@ def run(args):
     from transformers import TrainingArguments
     from unsloth import is_bfloat16_supported
     import logging
-    logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
+
+    logging.getLogger("hf-to-gguf").setLevel(logging.WARNING)
 
     # Load model and tokenizer
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name=args.model_name,
-        max_seq_length=args.max_seq_length,
-        dtype=args.dtype,
-        load_in_4bit=args.load_in_4bit,
+        model_name = args.model_name,
+        max_seq_length = args.max_seq_length,
+        dtype = args.dtype,
+        load_in_4bit = args.load_in_4bit,
     )
 
     # Configure PEFT model
     model = FastLanguageModel.get_peft_model(
         model,
-        r=args.r,
-        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
-                        "gate_proj", "up_proj", "down_proj"],
-        lora_alpha=args.lora_alpha,
-        lora_dropout=args.lora_dropout,
-        bias=args.bias,
-        use_gradient_checkpointing=args.use_gradient_checkpointing,
-        random_state=args.random_state,
-        use_rslora=args.use_rslora,
-        loftq_config=args.loftq_config,
+        r = args.r,
+        target_modules = [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+        lora_alpha = args.lora_alpha,
+        lora_dropout = args.lora_dropout,
+        bias = args.bias,
+        use_gradient_checkpointing = args.use_gradient_checkpointing,
+        random_state = args.random_state,
+        use_rslora = args.use_rslora,
+        loftq_config = args.loftq_config,
     )
 
     alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -79,53 +87,55 @@ def run(args):
     {}"""
 
     EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
+
     def formatting_prompts_func(examples):
         instructions = examples["instruction"]
-        inputs       = examples["input"]
-        outputs      = examples["output"]
+        inputs = examples["input"]
+        outputs = examples["output"]
         texts = []
         for instruction, input, output in zip(instructions, inputs, outputs):
             text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
             texts.append(text)
         return {"text": texts}
 
-    use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
+    use_modelscope = strtobool(os.environ.get("UNSLOTH_USE_MODELSCOPE", "False"))
     if use_modelscope:
         from modelscope import MsDataset
-        dataset = MsDataset.load(args.dataset, split="train")
+
+        dataset = MsDataset.load(args.dataset, split = "train")
     else:
         # Load and format dataset
-        dataset = load_dataset(args.dataset, split="train")
-    dataset = dataset.map(formatting_prompts_func, batched=True)
+        dataset = load_dataset(args.dataset, split = "train")
+    dataset = dataset.map(formatting_prompts_func, batched = True)
     print("Data is formatted and ready!")
 
     # Configure training arguments
     training_args = SFTConfig(
-        per_device_train_batch_size=args.per_device_train_batch_size,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        warmup_steps=args.warmup_steps,
-        max_steps=args.max_steps,
-        learning_rate=args.learning_rate,
-        fp16=not is_bfloat16_supported(),
-        bf16=is_bfloat16_supported(),
-        logging_steps=args.logging_steps,
-        optim=args.optim,
-        weight_decay=args.weight_decay,
-        lr_scheduler_type=args.lr_scheduler_type,
-        seed=args.seed,
-        output_dir=args.output_dir,
-        report_to=args.report_to,
-        max_length=args.max_seq_length,
-        dataset_num_proc=2,
-        packing=False,
+        per_device_train_batch_size = args.per_device_train_batch_size,
+        gradient_accumulation_steps = args.gradient_accumulation_steps,
+        warmup_steps = args.warmup_steps,
+        max_steps = args.max_steps,
+        learning_rate = args.learning_rate,
+        fp16 = not is_bfloat16_supported(),
+        bf16 = is_bfloat16_supported(),
+        logging_steps = args.logging_steps,
+        optim = args.optim,
+        weight_decay = args.weight_decay,
+        lr_scheduler_type = args.lr_scheduler_type,
+        seed = args.seed,
+        output_dir = args.output_dir,
+        report_to = args.report_to,
+        max_length = args.max_seq_length,
+        dataset_num_proc = 2,
+        packing = False,
     )
 
     # Initialize trainer
     trainer = SFTTrainer(
-        model=model,
-        processing_class=tokenizer,
-        train_dataset=dataset,
-        args=training_args,
+        model = model,
+        processing_class = tokenizer,
+        train_dataset = dataset,
+        args = training_args,
     )
 
     # Train model
@@ -137,26 +147,30 @@ def formatting_prompts_func(examples):
         if args.save_gguf:
             if isinstance(args.quantization, list):
                 for quantization_method in args.quantization:
-                    print(f"Saving model with quantization method: {quantization_method}")
+                    print(
+                        f"Saving model with quantization method: {quantization_method}"
+                    )
                     model.save_pretrained_gguf(
                         args.save_path,
                         tokenizer,
-                        quantization_method=quantization_method,
+                        quantization_method = quantization_method,
                     )
                     if args.push_model:
                         model.push_to_hub_gguf(
-                            hub_path=args.hub_path,
-                            hub_token=args.hub_token,
-                            quantization_method=quantization_method,
+                            hub_path = args.hub_path,
+                            hub_token = args.hub_token,
+                            quantization_method = quantization_method,
                         )
             else:
                 print(f"Saving model with quantization method: {args.quantization}")
-                model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
+                model.save_pretrained_gguf(
+                    args.save_path, tokenizer, quantization_method = args.quantization
+                )
                 if args.push_model:
                     model.push_to_hub_gguf(
-                        hub_path=args.hub_path,
-                        hub_token=args.hub_token,
-                        quantization_method=quantization_method,
+                        hub_path = args.hub_path,
+                        hub_token = args.hub_token,
+                        quantization_method = quantization_method,
                     )
         else:
             model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
@@ -167,62 +181,213 @@ def formatting_prompts_func(examples):
 
 
 if __name__ == "__main__":
-
     # Define argument parser
-    parser = argparse.ArgumentParser(description="𦄠Fine-tune your llm faster using unsloth!")
+    parser = argparse.ArgumentParser(
+        description = "𦄠Fine-tune your llm faster using unsloth!"
+    )
 
     model_group = parser.add_argument_group("š¤ Model Options")
-    model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
-    model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
-    model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
-    model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
-    model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
-
-    lora_group = parser.add_argument_group("š§  LoRA Options", "These options are used to configure the LoRA model.")
-    lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16.  (common values: 8, 16, 32, 64, 128)")
-    lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
-    lora_group.add_argument('--lora_dropout', type=float, default=0.0, help="LoRA dropout rate, default is 0.0 which is optimized.")
-    lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
-    lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
-    lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
-    lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
-    lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
-
-   
+    model_group.add_argument(
+        "--model_name",
+        type = str,
+        default = "unsloth/llama-3-8b",
+        help = "Model name to load",
+    )
+    model_group.add_argument(
+        "--max_seq_length",
+        type = int,
+        default = 2048,
+        help = "Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!",
+    )
+    model_group.add_argument(
+        "--dtype",
+        type = str,
+        default = None,
+        help = "Data type for model (None for auto detection)",
+    )
+    model_group.add_argument(
+        "--load_in_4bit",
+        action = "store_true",
+        help = "Use 4bit quantization to reduce memory usage",
+    )
+    model_group.add_argument(
+        "--dataset",
+        type = str,
+        default = "yahma/alpaca-cleaned",
+        help = "Huggingface dataset to use for training",
+    )
+
+    lora_group = parser.add_argument_group(
+        "š§  LoRA Options", "These options are used to configure the LoRA model."
+    )
+    lora_group.add_argument(
+        "--r",
+        type = int,
+        default = 16,
+        help = "Rank for Lora model, default is 16.  (common values: 8, 16, 32, 64, 128)",
+    )
+    lora_group.add_argument(
+        "--lora_alpha",
+        type = int,
+        default = 16,
+        help = "LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)",
+    )
+    lora_group.add_argument(
+        "--lora_dropout",
+        type = float,
+        default = 0.0,
+        help = "LoRA dropout rate, default is 0.0 which is optimized.",
+    )
+    lora_group.add_argument(
+        "--bias", type = str, default = "none", help = "Bias setting for LoRA"
+    )
+    lora_group.add_argument(
+        "--use_gradient_checkpointing",
+        type = str,
+        default = "unsloth",
+        help = "Use gradient checkpointing",
+    )
+    lora_group.add_argument(
+        "--random_state",
+        type = int,
+        default = 3407,
+        help = "Random state for reproducibility, default is 3407.",
+    )
+    lora_group.add_argument(
+        "--use_rslora", action = "store_true", help = "Use rank stabilized LoRA"
+    )
+    lora_group.add_argument(
+        "--loftq_config", type = str, default = None, help = "Configuration for LoftQ"
+    )
+
     training_group = parser.add_argument_group("š Training Options")
-    training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
-    training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
-    training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
-    training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
-    training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
-    training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
-    training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
-    training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
-    training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
-    
+    training_group.add_argument(
+        "--per_device_train_batch_size",
+        type = int,
+        default = 2,
+        help = "Batch size per device during training, default is 2.",
+    )
+    training_group.add_argument(
+        "--gradient_accumulation_steps",
+        type = int,
+        default = 4,
+        help = "Number of gradient accumulation steps, default is 4.",
+    )
+    training_group.add_argument(
+        "--warmup_steps",
+        type = int,
+        default = 5,
+        help = "Number of warmup steps, default is 5.",
+    )
+    training_group.add_argument(
+        "--max_steps", type = int, default = 400, help = "Maximum number of training steps."
+    )
+    training_group.add_argument(
+        "--learning_rate",
+        type = float,
+        default = 2e-4,
+        help = "Learning rate, default is 2e-4.",
+    )
+    training_group.add_argument(
+        "--optim", type = str, default = "adamw_8bit", help = "Optimizer type."
+    )
+    training_group.add_argument(
+        "--weight_decay",
+        type = float,
+        default = 0.01,
+        help = "Weight decay, default is 0.01.",
+    )
+    training_group.add_argument(
+        "--lr_scheduler_type",
+        type = str,
+        default = "linear",
+        help = "Learning rate scheduler type, default is 'linear'.",
+    )
+    training_group.add_argument(
+        "--seed",
+        type = int,
+        default = 3407,
+        help = "Seed for reproducibility, default is 3407.",
+    )
 
     # Report/Logging arguments
     report_group = parser.add_argument_group("š Report Options")
-    report_group.add_argument('--report_to', type=str, default="tensorboard",
-        choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
-        help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
-    report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
+    report_group.add_argument(
+        "--report_to",
+        type = str,
+        default = "tensorboard",
+        choices = [
+            "azure_ml",
+            "clearml",
+            "codecarbon",
+            "comet_ml",
+            "dagshub",
+            "dvclive",
+            "flyte",
+            "mlflow",
+            "neptune",
+            "tensorboard",
+            "wandb",
+            "all",
+            "none",
+        ],
+        help = "The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.",
+    )
+    report_group.add_argument(
+        "--logging_steps", type = int, default = 1, help = "Logging steps, default is 1"
+    )
 
     # Saving and pushing arguments
-    save_group = parser.add_argument_group('š¾ Save Model Options')
-    save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
-    save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
-    save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
-    save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
-    save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
-    save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
-        help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")
-
-    push_group = parser.add_argument_group('š Push Model Options')
-    push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
-    push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
-    push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
-    push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
+    save_group = parser.add_argument_group("š¾ Save Model Options")
+    save_group.add_argument(
+        "--output_dir", type = str, default = "outputs", help = "Output directory"
+    )
+    save_group.add_argument(
+        "--save_model", action = "store_true", help = "Save the model after training"
+    )
+    save_group.add_argument(
+        "--save_method",
+        type = str,
+        default = "merged_16bit",
+        choices = ["merged_16bit", "merged_4bit", "lora"],
+        help = "Save method for the model, default is 'merged_16bit'",
+    )
+    save_group.add_argument(
+        "--save_gguf",
+        action = "store_true",
+        help = "Convert the model to GGUF after training",
+    )
+    save_group.add_argument(
+        "--save_path", type = str, default = "model", help = "Path to save the model"
+    )
+    save_group.add_argument(
+        "--quantization",
+        type = str,
+        default = "q8_0",
+        nargs = "+",
+        help = "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ",
+    )
+
+    push_group = parser.add_argument_group("š Push Model Options")
+    push_group.add_argument(
+        "--push_model",
+        action = "store_true",
+        help = "Push the model to Hugging Face hub after training",
+    )
+    push_group.add_argument(
+        "--push_gguf",
+        action = "store_true",
+        help = "Push the model as GGUF to Hugging Face hub after training",
+    )
+    push_group.add_argument(
+        "--hub_path",
+        type = str,
+        default = "hf/model",
+        help = "Path on Hugging Face hub to push the model",
+    )
+    push_group.add_argument(
+        "--hub_token", type = str, help = "Token for pushing the model to Hugging Face hub"
+    )
 
     args = parser.parse_args()
     run(args)
diff --git a/unsloth/__init__.py b/unsloth/__init__.py
index b48bc42db..cacb1087f 100644
--- a/unsloth/__init__.py
+++ b/unsloth/__init__.py
@@ -19,10 +19,11 @@
 
 # Fix some issues before importing other packages
 from .import_fixes import fix_message_factory_issue
-fix_message_factory_issue(); del fix_message_factory_issue;
 
+fix_message_factory_issue()
+del fix_message_factory_issue
 # Check if modules that need patching are already imported
-critical_modules = ['trl', 'transformers', 'peft']
+critical_modules = ["trl", "transformers", "peft"]
 already_imported = [mod for mod in critical_modules if mod in sys.modules]
 
 # This check is critical because Unsloth optimizes these libraries by modifying
@@ -40,7 +41,6 @@
         f"Please restructure your imports with 'import unsloth' at the top of your file.",
         stacklevel = 2,
     )
-pass
 
 # Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
 # enabling it will require much more work, so we have to prioritize. Please understand!
@@ -62,22 +62,22 @@
     import torch
 except ModuleNotFoundError:
     raise ImportError(
-        "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n"\
+        "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n"
         "We have some installation instructions on our Github page."
     )
 except Exception as exception:
     raise exception
-pass
 
 import importlib.util
 from pathlib import Path
 from importlib.metadata import version as importlib_version
+
 # Check for unsloth_zoo
 try:
     unsloth_zoo_version = importlib_version("unsloth_zoo")
     if Version(unsloth_zoo_version) < Version("2025.10.12"):
         print(
-            "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n"\
+            "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n"
             "Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`"
         )
         # if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0":
@@ -92,8 +92,9 @@
 except NotImplementedError as e:
     raise NotImplementedError(str(e))
 except Exception as e:
-    raise ImportError(f"Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo` Also error = {str(e)}")
-pass
+    raise ImportError(
+        f"Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo` Also error = {str(e)}"
+    )
 
 from unsloth_zoo.device_type import (
     is_hip,
@@ -106,66 +107,86 @@
 
 # Fix other issues
 from .import_fixes import fix_xformers_performance_issue
-fix_xformers_performance_issue(); del fix_xformers_performance_issue;
+
+fix_xformers_performance_issue()
+del fix_xformers_performance_issue
 from .import_fixes import fix_vllm_aimv2_issue
-fix_vllm_aimv2_issue(); del fix_vllm_aimv2_issue;
+
+fix_vllm_aimv2_issue()
+del fix_vllm_aimv2_issue
 from .import_fixes import ignore_logger_messages
-ignore_logger_messages(); del ignore_logger_messages;
+
+ignore_logger_messages()
+del ignore_logger_messages
 from .import_fixes import patch_ipykernel_hf_xet
-patch_ipykernel_hf_xet(); del patch_ipykernel_hf_xet;
+
+patch_ipykernel_hf_xet()
+del patch_ipykernel_hf_xet
 from .import_fixes import patch_trackio
-patch_trackio(); del patch_trackio;
 
+patch_trackio()
+del patch_trackio
 # Torch 2.4 has including_emulation
 if DEVICE_TYPE == "cuda":
     major_version, minor_version = torch.cuda.get_device_capability()
-    SUPPORTS_BFLOAT16 = (major_version >= 8)
+    SUPPORTS_BFLOAT16 = major_version >= 8
 
     old_is_bf16_supported = torch.cuda.is_bf16_supported
     if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
+
         def is_bf16_supported(including_emulation = False):
             return old_is_bf16_supported(including_emulation)
+
         torch.cuda.is_bf16_supported = is_bf16_supported
     else:
-        def is_bf16_supported(): return SUPPORTS_BFLOAT16
+
+        def is_bf16_supported():
+            return SUPPORTS_BFLOAT16
+
         torch.cuda.is_bf16_supported = is_bf16_supported
-    pass
 elif DEVICE_TYPE == "hip":
     SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
 elif DEVICE_TYPE == "xpu":
     # torch.xpu.is_bf16_supported() does not have including_emulation
     # set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()
     SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()
-pass
 
 # For Gradio HF Spaces?
 # if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
 import triton
+
 if DEVICE_TYPE == "cuda":
     libcuda_dirs = lambda: None
     if Version(triton.__version__) >= Version("3.0.0"):
-        try: from triton.backends.nvidia.driver import libcuda_dirs
-        except: pass
-    else: from triton.common.build import libcuda_dirs
+        try:
+            from triton.backends.nvidia.driver import libcuda_dirs
+        except:
+            pass
+    else:
+        from triton.common.build import libcuda_dirs
 
     # Try loading bitsandbytes and triton
     try:
         import bitsandbytes as bnb
     except:
-        print("Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!")
+        print(
+            "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!"
+        )
     try:
         cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
         libcuda_dirs()
     except:
-        warnings.warn(
-            "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
-        )
+        warnings.warn("Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA.")
 
         if os.path.exists("/usr/lib64-nvidia"):
             os.system("ldconfig /usr/lib64-nvidia")
         elif os.path.exists("/usr/local"):
             # Sometimes bitsandbytes cannot be linked properly in Runpod for example
-            possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
+            possible_cudas = (
+                subprocess.check_output(["ls", "-al", "/usr/local"])
+                .decode("utf-8")
+                .split("\n")
+            )
             find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
             possible_cudas = [find_cuda.search(x) for x in possible_cudas]
             possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
@@ -175,38 +196,41 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
                 os.system("ldconfig /usr/local/")
             else:
                 find_number = re.compile(r"([\d\.]{2,})")
-                latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
+                latest_cuda = np.argsort(
+                    [float(find_number.search(x).group(1)) for x in possible_cudas]
+                )[::-1][0]
                 latest_cuda = possible_cudas[latest_cuda]
                 os.system(f"ldconfig /usr/local/{latest_cuda}")
-        pass
 
         importlib.reload(bnb)
         importlib.reload(triton)
         try:
             libcuda_dirs = lambda: None
             if Version(triton.__version__) >= Version("3.0.0"):
-                try: from triton.backends.nvidia.driver import libcuda_dirs
-                except: pass
-            else: from triton.common.build import libcuda_dirs
+                try:
+                    from triton.backends.nvidia.driver import libcuda_dirs
+                except:
+                    pass
+            else:
+                from triton.common.build import libcuda_dirs
             cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
             libcuda_dirs()
         except:
             warnings.warn(
-                "Unsloth: CUDA is not linked properly.\n"\
-                "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
-                "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
-                "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
-                "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
+                "Unsloth: CUDA is not linked properly.\n"
+                "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"
+                "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"
+                "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"
+                "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"
                 "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
             )
-    pass
 elif DEVICE_TYPE == "hip":
     # NO-OP for rocm device
     pass
 elif DEVICE_TYPE == "xpu":
     import bitsandbytes as bnb
+
     # TODO: check triton for intel installed properly.
-    pass
 
 from .models import *
 from .models import __version__
diff --git a/unsloth/_auto_install.py b/unsloth/_auto_install.py
index f6cf6cfc9..5121c7ea7 100644
--- a/unsloth/_auto_install.py
+++ b/unsloth/_auto_install.py
@@ -12,29 +12,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-try: import torch
-except: raise ImportError('Install torch via `pip install torch`')
+try:
+    import torch
+except:
+    raise ImportError("Install torch via `pip install torch`")
 from packaging.version import Version as V
 import re
+
 v = V(re.match(r"[0-9\.]{3,}", torch.__version__).group(0))
 cuda = str(torch.version.cuda)
 is_ampere = torch.cuda.get_device_capability()[0] >= 8
 USE_ABI = torch._C._GLIBCXX_USE_CXX11_ABI
-if cuda not in ("11.8", "12.1", "12.4", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} not supported!")
-if   v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
-elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
-elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
-elif v  < V('2.3.0'): x = 'cu{}{}-torch220'
-elif v  < V('2.4.0'): x = 'cu{}{}-torch230'
-elif v  < V('2.5.0'): x = 'cu{}{}-torch240'
-elif v  < V('2.5.1'): x = 'cu{}{}-torch250'
-elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
-elif v  < V('2.7.0'): x = 'cu{}{}-torch260'
-elif v  < V('2.7.9'): x = 'cu{}{}-torch270'
-elif v  < V('2.8.0'): x = 'cu{}{}-torch271'
-elif v  < V('2.8.9'): x = 'cu{}{}-torch280'
-elif v  < V('2.9.1'): x = 'cu{}{}-torch290'
-else: raise RuntimeError(f"Torch = {v} too new!")
-if v > V('2.6.9') and cuda not in ("11.8", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} not supported!")
+if cuda not in ("11.8", "12.1", "12.4", "12.6", "12.8", "13.0"):
+    raise RuntimeError(f"CUDA = {cuda} not supported!")
+if v <= V("2.1.0"):
+    raise RuntimeError(f"Torch = {v} too old!")
+elif v <= V("2.1.1"):
+    x = "cu{}{}-torch211"
+elif v <= V("2.1.2"):
+    x = "cu{}{}-torch212"
+elif v < V("2.3.0"):
+    x = "cu{}{}-torch220"
+elif v < V("2.4.0"):
+    x = "cu{}{}-torch230"
+elif v < V("2.5.0"):
+    x = "cu{}{}-torch240"
+elif v < V("2.5.1"):
+    x = "cu{}{}-torch250"
+elif v <= V("2.5.1"):
+    x = "cu{}{}-torch251"
+elif v < V("2.7.0"):
+    x = "cu{}{}-torch260"
+elif v < V("2.7.9"):
+    x = "cu{}{}-torch270"
+elif v < V("2.8.0"):
+    x = "cu{}{}-torch271"
+elif v < V("2.8.9"):
+    x = "cu{}{}-torch280"
+elif v < V("2.9.1"):
+    x = "cu{}{}-torch290"
+else:
+    raise RuntimeError(f"Torch = {v} too new!")
+if v > V("2.6.9") and cuda not in ("11.8", "12.6", "12.8", "13.0"):
+    raise RuntimeError(f"CUDA = {cuda} not supported!")
 x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
-print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
\ No newline at end of file
+print(
+    f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
+)
diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index 7962c27f9..9393c6b7c 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -17,13 +17,11 @@
     "test_chat_templates",
     "test_hf_gguf_equivalence",
     "remove_special_tokens",
-
     "to_sharegpt",
     "standardize_sharegpt",
     "standardize_data_formats",
     "apply_chat_template",
     "train_on_responses_only",
-
     "test_construct_chat_template",
 ]
 
@@ -40,37 +38,37 @@
     train_on_responses_only,
     standardize_data_formats,
 )
+
 standardize_sharegpt = standardize_data_formats
 CHAT_TEMPLATES = {}
 DEFAULT_SYSTEM_MESSAGE = {}
 
 # =========================================== Unsloth
 # Unsloth efficient template leverages from Zephyr
-unsloth_template = \
-    "{{ bos_token }}"\
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{{ messages[0]['content'] + '\n' }}"\
-        "{% set loop_messages = messages[1:] %}"\
-    "{% else %}"\
-        "{{ '{system_message}' + '\n' }}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '>>> User: ' + message['content'] + '\n' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '>>> Assistant: ' }}"\
+unsloth_template = (
+    "{{ bos_token }}"
+    "{% if messages[0]['role'] == 'system' %}"
+    "{{ messages[0]['content'] + '\n' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% else %}"
+    "{{ '{system_message}' + '\n' }}"
+    "{% set loop_messages = messages %}"
+    "{% endif %}"
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '>>> User: ' + message['content'] + '\n' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
     "{% endif %}"
-pass
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '>>> Assistant: ' }}"
+    "{% endif %}"
+)
 
-unsloth_ollama = \
-'''
+unsloth_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }}
 {{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
@@ -83,29 +81,32 @@
 '''
 
 unsloth_eos_token = "eos_token"
-CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
+CHAT_TEMPLATES["unsloth"] = (
+    unsloth_template,
+    unsloth_eos_token,
+    False,
+    unsloth_ollama,
+)
 DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user"
-pass
 
 # =========================================== Zephyr
 # Zephyr has no BOS!
-zephyr_template = \
-    "{% for message in messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
-        "{% else %}"\
-            "{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '<|assistant|>\n' }}"\
+zephyr_template = (
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"
+    "{% else %}"
+    "{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '<|assistant|>\n' }}"
     "{% endif %}"
-pass
+)
 
-zephyr_ollama = \
-'''
+zephyr_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|system|>
 {{ .System }}{__EOS_TOKEN__}
@@ -120,29 +121,32 @@
 '''
 
 zephyr_eos_token = "eos_token"
-CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
-DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr
-pass
+CHAT_TEMPLATES["zephyr"] = (
+    zephyr_template,
+    zephyr_eos_token,
+    False,
+    zephyr_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["zephyr"] = None  # No system message in Zephyr
 
 # =========================================== ChatML
 # ChatML has no BOS and not EOS! Rather <|im_start|> and <|im_end|> acts as BOS / EOS.
-chatml_template = \
-    "{% for message in messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"\
-        "{% else %}"\
-            "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '<|im_start|>assistant\n' }}"\
+chatml_template = (
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"
+    "{% else %}"
+    "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '<|im_start|>assistant\n' }}"
     "{% endif %}"
-pass
+)
 
-chatml_ollama = \
-'''
+chatml_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
@@ -158,39 +162,42 @@
 '''
 
 chatml_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
-DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML
-pass
+CHAT_TEMPLATES["chatml"] = (
+    chatml_template,
+    chatml_eos_token,
+    True,
+    chatml_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["chatml"] = None  # No system message in ChatML
 
 # =========================================== Mistral-1
 # Mistral Instruct doesn't allow system prompts, so we append it to the user message.
-mistral_template = \
-    "{{ bos_token }}"\
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{% if messages[1]['role'] == 'user' %}"\
-            "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[2:] %}"\
-        "{% else %}"\
-            "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[1:] %}"\
-        "{% endif %}"\
-    "{% else %}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ message['content'] + eos_token }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
+mistral_template = (
+    "{{ bos_token }}"
+    "{% if messages[0]['role'] == 'system' %}"
+    "{% if messages[1]['role'] == 'user' %}"
+    "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[2:] %}"
+    "{% else %}"
+    "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% endif %}"
+    "{% else %}"
+    "{% set loop_messages = messages %}"
+    "{% endif %}"
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ message['content'] + eos_token }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
     "{% endfor %}"
-pass
+)
 
 # Ollama from https://www.ollama.com/library/mistral
-mistral_ollama = \
-'''
+mistral_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
 PARAMETER stop "{__EOS_TOKEN__}"
@@ -199,38 +206,41 @@
 '''
 
 mistral_eos_token = "eos_token"
-CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
-DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral
-pass
+CHAT_TEMPLATES["mistral"] = (
+    mistral_template,
+    mistral_eos_token,
+    False,
+    mistral_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["mistral"] = None  # No system message in Mistral
 
 # =========================================== Llama-2
 # Adds BOS to every convo! And weird <> system messages.
-llama_template = \
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{% if messages[1]['role'] == 'user' %}"\
-            "{{ bos_token + '[INST] <>\n' + messages[0]['content'] + '\n<>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[2:] %}"\
-        "{% else %}"\
-            "{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[1:] %}"\
-        "{% endif %}"\
-    "{% else %}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ ' ' + message['content'].strip() + ' ' + eos_token }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
+llama_template = (
+    "{% if messages[0]['role'] == 'system' %}"
+    "{% if messages[1]['role'] == 'user' %}"
+    "{{ bos_token + '[INST] <>\n' + messages[0]['content'] + '\n<>\n\n' + messages[1]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[2:] %}"
+    "{% else %}"
+    "{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% endif %}"
+    "{% else %}"
+    "{% set loop_messages = messages %}"
+    "{% endif %}"
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ ' ' + message['content'].strip() + ' ' + eos_token }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
     "{% endfor %}"
-pass
+)
 
 # Ollama from https://www.ollama.com/library/llama3
-llama_ollama = \
-'''
+llama_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """[INST] <>{{ .System }}<>
 
@@ -241,38 +251,41 @@
 '''
 
 llama_eos_token = "eos_token"
-CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
-DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama
-pass
+CHAT_TEMPLATES["llama"] = (
+    llama_template,
+    llama_eos_token,
+    False,
+    llama_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["llama"] = None  # No system message in Llama
 
 # ===========================================  Vicuna
 # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
-vicuna_template = \
-    "{{ bos_token }}"\
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{{ messages[0]['content'] + ' ' }}"\
-        "{% set loop_messages = messages[1:] %}"\
-    "{% else %}"\
-        "{{ '{system_message}' + ' ' }}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ 'USER: ' + message['content'] + ' ' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ 'ASSISTANT: ' + message['content'] + eos_token }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ 'ASSISTANT:' }}"\
+vicuna_template = (
+    "{{ bos_token }}"
+    "{% if messages[0]['role'] == 'system' %}"
+    "{{ messages[0]['content'] + ' ' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% else %}"
+    "{{ '{system_message}' + ' ' }}"
+    "{% set loop_messages = messages %}"
+    "{% endif %}"
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ 'USER: ' + message['content'] + ' ' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ 'ASSISTANT: ' + message['content'] + eos_token }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ 'ASSISTANT:' }}"
     "{% endif %}"
-pass
+)
 
 # Ollama from https://www.ollama.com/library/vicuna
-vicuna_ollama = \
-'''
+vicuna_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
 PARAMETER stop "{__EOS_TOKEN__}"
@@ -281,37 +294,42 @@
 '''
 
 vicuna_eos_token = "eos_token"
-CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
-DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
-pass
+CHAT_TEMPLATES["vicuna"] = (
+    vicuna_template,
+    vicuna_eos_token,
+    False,
+    vicuna_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["vicuna"] = (
+    "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
+)
 
 # =========================================== Vicuna Old
 # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
-vicuna_old_template = \
-    "{{ bos_token }}"\
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{{ messages[0]['content'] + '\n' }}"\
-        "{% set loop_messages = messages[1:] %}"\
-    "{% else %}"\
-        "{{ '{system_message}' + '\n' }}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '### Human: ' + message['content'] + '\n' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '### Assistant:' }}"\
+vicuna_old_template = (
+    "{{ bos_token }}"
+    "{% if messages[0]['role'] == 'system' %}"
+    "{{ messages[0]['content'] + '\n' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% else %}"
+    "{{ '{system_message}' + '\n' }}"
+    "{% set loop_messages = messages %}"
     "{% endif %}"
-pass
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '### Human: ' + message['content'] + '\n' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '### Assistant:' }}"
+    "{% endif %}"
+)
 
-vicuna_old_ollama = \
-'''
+vicuna_old_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }}
 {{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
@@ -324,40 +342,45 @@
 '''
 
 vicuna_old_eos_token = "eos_token"
-CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
-DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions."
+CHAT_TEMPLATES["vicuna_old"] = (
+    vicuna_old_template,
+    vicuna_old_eos_token,
+    False,
+    vicuna_old_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = (
+    "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions."
+)
 
 CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"]
 DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"]
-pass
 
 # =========================================== Alpaca multi turn
 # https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
-alpaca_template = \
-    "{{ bos_token }}"\
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{{ messages[0]['content'] + '\n\n' }}"\
-        "{% set loop_messages = messages[1:] %}"\
-    "{% else %}"\
-        "{{ '{system_message}' + '\n\n' }}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '### Instruction:\n' + message['content'] + '\n\n' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '### Response:\n' }}"\
+alpaca_template = (
+    "{{ bos_token }}"
+    "{% if messages[0]['role'] == 'system' %}"
+    "{{ messages[0]['content'] + '\n\n' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% else %}"
+    "{{ '{system_message}' + '\n\n' }}"
+    "{% set loop_messages = messages %}"
     "{% endif %}"
-pass
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '### Instruction:\n' + message['content'] + '\n\n' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '### Response:\n' }}"
+    "{% endif %}"
+)
 
-alpaca_ollama = \
-'''
+alpaca_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }}
 
@@ -375,37 +398,42 @@
 '''
 
 alpaca_eos_token = "eos_token"
-CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
-DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
-pass
+CHAT_TEMPLATES["alpaca"] = (
+    alpaca_template,
+    alpaca_eos_token,
+    False,
+    alpaca_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["alpaca"] = (
+    "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
+)
 
 # =========================================== Gemma
 # https://huggingface.co/google/gemma-7b-it
 # Notice we must use |trim for lstrip and rstrip.  maps to 106.
 #  maps to 107. user and model are normal 1 word tokens.
-gemma_template = \
-    "{{ bos_token }}"\
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{{'user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '\n'}}"\
-        "{% set messages = messages[2:] %}"\
-    "{% endif %}"\
-    "{% for message in messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{'user\n' + message['content'] | trim + '\n'}}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{'model\n' + message['content'] | trim + '\n' }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ 'model\n' }}"\
+gemma_template = (
+    "{{ bos_token }}"
+    "{% if messages[0]['role'] == 'system' %}"
+    "{{'user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '\n'}}"
+    "{% set messages = messages[2:] %}"
     "{% endif %}"
-pass
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{'user\n' + message['content'] | trim + '\n'}}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{'model\n' + message['content'] | trim + '\n' }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ 'model\n' }}"
+    "{% endif %}"
+)
 
 # Ollama from https://www.ollama.com/library/gemma
-gemma_ollama = \
-'''
+gemma_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """user
 {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}
@@ -421,17 +449,19 @@
 '''
 
 gemma_eos_token = ""
-CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma
-pass
+CHAT_TEMPLATES["gemma"] = (
+    gemma_template,
+    gemma_eos_token,
+    True,
+    gemma_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma"] = None  # No system message in Gemma
 
 # =========================================== Gemma with ChatML instead
 # We find using  is still more appropriate!
 gemma_chatml_template = "{{ bos_token }}" + chatml_template
-pass
 
-gemma_chatml_ollama = \
-'''
+gemma_chatml_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
@@ -449,12 +479,16 @@
 '''
 
 gemma_chatml_eos_token = (
-    {"" : "<|im_start|>", "" : "<|im_end|>"},
+    {"": "<|im_start|>", "": "<|im_end|>"},
     "<|im_end|>",
 )
-CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma
-pass
+CHAT_TEMPLATES["gemma_chatml"] = (
+    gemma_chatml_template,
+    gemma_chatml_eos_token,
+    True,
+    gemma_chatml_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None  # No system message in Gemma
 
 # =========================================== Gemma 2
 # Same as Gemma 1, but with sliding window attention!
@@ -462,38 +496,46 @@
 gemma2_template = gemma_template
 gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
 gemma2_eos_token = ""
-CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2
+CHAT_TEMPLATES["gemma2"] = (
+    gemma2_template,
+    gemma2_eos_token,
+    True,
+    gemma2_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma2"] = None  # No system message in Gemma 2
 
 # =========================================== Gemma 2 with ChatML instead
 gemma2_chatml_template = gemma_chatml_template
 gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
 gemma2_chatml_eos_token = gemma_chatml_eos_token
-CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2
-pass
+CHAT_TEMPLATES["gemma2_chatml"] = (
+    gemma2_chatml_template,
+    gemma2_chatml_eos_token,
+    True,
+    gemma2_chatml_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None  # No system message in Gemma 2
 
 # =========================================== Llama-3
 # Weirdly \n\n is needed?
-llama3_template = \
-    "{{ bos_token }}"\
-    "{% for message in messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
-        "{% else %}"\
-            "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
+llama3_template = (
+    "{{ bos_token }}"
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"
+    "{% else %}"
+    "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
     "{% endif %}"
-pass
+)
 
 # Ollama from https://www.ollama.com/library/llama3
-llama3_ollama = \
-'''
+llama3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
 
@@ -511,34 +553,42 @@
 
 llama3_template_eos_token = "eos_token"
 
-CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["llama-3"] = None # No system message in Llama-3
+CHAT_TEMPLATES["llama-3"] = (
+    llama3_template,
+    llama3_template_eos_token,
+    False,
+    llama3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["llama-3"] = None  # No system message in Llama-3
 
-CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3
-pass
+CHAT_TEMPLATES["llama3"] = (
+    llama3_template,
+    llama3_template_eos_token,
+    False,
+    llama3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["llama3"] = None  # No system message in Llama-3
 
 
 # =========================================== Phi-3
 # "{{ bos_token }}"\ # Phi-3.5 removes BOS?
-phi3_template = \
-    "{% for message in messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"\
-        "{% else %}"\
-            "{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '<|assistant|>\n' }}"\
+phi3_template = (
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"
+    "{% else %}"
+    "{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"
+    "{% endif %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '<|assistant|>\n' }}"
     "{% endif %}"
-pass
+)
 
 # Ollama from https://www.ollama.com/library/phi3
-phi3_ollama = \
-'''
+phi3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|system|>
 {{ .System }}<|end|>
@@ -555,15 +605,19 @@
 '''
 
 phi3_template_eos_token = "<|end|>"
-CHAT_TEMPLATES["phi-3"]   = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3
+CHAT_TEMPLATES["phi-3"] = (
+    phi3_template,
+    phi3_template_eos_token,
+    False,
+    phi3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["phi-3"] = None  # No system message in Phi-3
 
-CHAT_TEMPLATES["phi-35"]  = CHAT_TEMPLATES["phi-3"]
-DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5
+CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
+DEFAULT_SYSTEM_MESSAGE["phi-35"] = None  # No system message in Phi-3.5
 
 CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
-DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5
-pass
+DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None  # No system message in Phi-3.5
 
 # =========================================== Llama-3.1
 """
@@ -582,8 +636,7 @@
 )
 """
 
-llama31_template = \
-"""{{- bos_token }}
+llama31_template = """{{- bos_token }}
 {%- if custom_tools is defined %}
     {%- set tools = custom_tools %}
 {%- endif %}
@@ -693,11 +746,9 @@
     {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
 {%- endif %}
 """
-pass
 
 # Ollama from https://ollama.com/library/llama3.1 (needs updating!)
-llama31_ollama = \
-'''
+llama31_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .Messages }}
 {{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
@@ -757,21 +808,33 @@
 '''
 
 llama31_template_eos_token = "eos_token"
-CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
-DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = "" # Llama3.1 default system message is empty + the dates
+CHAT_TEMPLATES["llama-3.1"] = (
+    llama31_template,
+    llama31_template_eos_token,
+    False,
+    llama31_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = (
+    ""  # Llama3.1 default system message is empty + the dates
+)
 
-CHAT_TEMPLATES["llama-31"]  = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
-DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates
+CHAT_TEMPLATES["llama-31"] = (
+    llama31_template,
+    llama31_template_eos_token,
+    False,
+    llama31_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["llama-31"] = (
+    ""  # Llama3.1 default system message is empty + the dates
+)
 
 for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"):
     CHAT_TEMPLATES[version] = CHAT_TEMPLATES["llama-3.1"]
     DEFAULT_SYSTEM_MESSAGE[version] = ""
-pass
 
 
 # =========================================== Qwen 2.5
-qwen25_template = \
-"""{%- if tools %}
+qwen25_template = """{%- if tools %}
     {{- \'<|im_start|>system\\n\' }}
     {%- if messages[0][\'role\'] == \'system\' %}
         {{- messages[0][\'content\'] }}
@@ -823,8 +886,7 @@
 
 
 # Ollama from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
-qwen25_ollama = \
-'''
+qwen25_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if .Messages }}
 {{- if or .System .Tools }}<|im_start|>system
@@ -883,45 +945,74 @@
 '''
 
 qwen25_template_eos_token = "eos_token"
-qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
-CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
+qwen25_default_system_message = (
+    "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
+)
+CHAT_TEMPLATES["qwen-2.5"] = (
+    qwen25_template,
+    qwen25_template_eos_token,
+    False,
+    qwen25_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = (
+    qwen25_default_system_message  # No system message in Qwen 2.5
+)
 
-CHAT_TEMPLATES["qwen-25"]  = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen-25"] = qwen25_default_system_message # No system message in Qwen 2.5
+CHAT_TEMPLATES["qwen-25"] = (
+    qwen25_template,
+    qwen25_template_eos_token,
+    False,
+    qwen25_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen-25"] = (
+    qwen25_default_system_message  # No system message in Qwen 2.5
+)
 
-CHAT_TEMPLATES["qwen25"]   = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen25"] = qwen25_default_system_message # No system message in Qwen 2.5
+CHAT_TEMPLATES["qwen25"] = (
+    qwen25_template,
+    qwen25_template_eos_token,
+    False,
+    qwen25_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen25"] = (
+    qwen25_default_system_message  # No system message in Qwen 2.5
+)
 
-CHAT_TEMPLATES["qwen2.5"]  = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
-pass
+CHAT_TEMPLATES["qwen2.5"] = (
+    qwen25_template,
+    qwen25_template_eos_token,
+    False,
+    qwen25_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = (
+    qwen25_default_system_message  # No system message in Qwen 2.5
+)
 
 # =========================================== Phi-4
 # "{{ bos_token }}"\ # Phi-4 removes BOS?
-phi4_template = \
-    "{% for message in messages %}"\
-        "{% if (message['role'] == 'system') %}"\
-            "{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}"\
-        "{% elif (message['role'] == 'user') %}"\
-            "{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}"\
-        "{% elif (message['role'] == 'assistant') %}"\
-            "{{'<|im_start|>assistant<|im_sep|>' + message['content'] + '<|im_end|>'}}"\
-        "{% endif %}"\
-    "{% endfor %}"\
-    "{% if add_generation_prompt %}"\
-        "{{ '<|im_start|>assistant<|im_sep|>' }}"\
+phi4_template = (
+    "{% for message in messages %}"
+    "{% if (message['role'] == 'system') %}"
+    "{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}"
+    "{% elif (message['role'] == 'user') %}"
+    "{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}"
+    "{% elif (message['role'] == 'assistant') %}"
+    "{{'<|im_start|>assistant<|im_sep|>' + message['content'] + '<|im_end|>'}}"
     "{% endif %}"
-pass
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '<|im_start|>assistant<|im_sep|>' }}"
+    "{% endif %}"
+)
 
-_phi4_ollama_template = \
-    "{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"\
-    "{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}"\
+_phi4_ollama_template = (
+    "{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"
+    "{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}"
     "<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>"
+)
 
 # Ollama from https://www.ollama.com/library/phi4 is different
-phi4_ollama = \
-f'''
+phi4_ollama = f'''
 FROM {{__FILE_LOCATION__}}
 TEMPLATE """{_phi4_ollama_template}"""
 PARAMETER stop "<|im_end|>"
@@ -932,16 +1023,19 @@
 '''
 
 phi4_template_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["phi-4"] = (phi4_template, phi4_template_eos_token, False, phi4_ollama,)
-DEFAULT_SYSTEM_MESSAGE["phi-4"] = None # No system message in Phi-4
-pass
+CHAT_TEMPLATES["phi-4"] = (
+    phi4_template,
+    phi4_template_eos_token,
+    False,
+    phi4_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["phi-4"] = None  # No system message in Phi-4
 
 
 # =========================================== Gemma-3
 # Obtained via
 # print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
-gemma3_template = \
-"""{{ bos_token }}
+gemma3_template = """{{ bos_token }}
 {%- if messages[0]['role'] == 'system' -%}
     {%- if messages[0]['content'] is string -%}
         {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
@@ -984,8 +1078,7 @@
 """
 
 # Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802
-gemma3_ollama = \
-'''
+gemma3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $i, $_ := .Messages }}
 {{- $last := eq (len (slice $.Messages $i)) 1 }}
@@ -1008,17 +1101,25 @@
 '''
 
 gemma3_template_eos_token = ""
-CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma-3"] = None # No system message in Gemma-3
+CHAT_TEMPLATES["gemma-3"] = (
+    gemma3_template,
+    gemma3_template_eos_token,
+    False,
+    gemma3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma-3"] = None  # No system message in Gemma-3
 
-CHAT_TEMPLATES["gemma3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma3"] = None # No system message in Gemma-3
-pass
+CHAT_TEMPLATES["gemma3"] = (
+    gemma3_template,
+    gemma3_template_eos_token,
+    False,
+    gemma3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma3"] = None  # No system message in Gemma-3
 
 # =========================================== Qwen-3
 # Official Qwen-3 chat template (see https://ollama.com/library/qwen3/blobs/eb4402837c78)
-qwen3_template = \
-"""
+qwen3_template = """
 {%- if tools %}
     {{- '<|im_start|>system\n' }}
     {%- if messages[0].role == 'system' %}
@@ -1120,8 +1221,7 @@
 """
 
 # Ollama template for Qwen-3 (see https://ollama.com/library/qwen3/blobs/eb4402837c78)
-qwen3_ollama = \
-'''
+qwen3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if .Messages }}
 {{- if or .System .Tools }}<|im_start|>system
@@ -1183,18 +1283,26 @@
 '''
 
 qwen3_template_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["qwen-3"] = (qwen3_template, qwen3_template_eos_token, False, qwen3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen-3"] = None # No default system message for Qwen-3
+CHAT_TEMPLATES["qwen-3"] = (
+    qwen3_template,
+    qwen3_template_eos_token,
+    False,
+    qwen3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen-3"] = None  # No default system message for Qwen-3
 
-CHAT_TEMPLATES["qwen3"] = (qwen3_template, qwen3_template_eos_token, False, qwen3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen3"] = None # No default system message for Qwen-3
-pass
+CHAT_TEMPLATES["qwen3"] = (
+    qwen3_template,
+    qwen3_template_eos_token,
+    False,
+    qwen3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen3"] = None  # No default system message for Qwen-3
 
 # =========================================== Gemma-3n
 # Obtained via
 # print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
-gemma3n_template = \
-"""{{ bos_token }}
+gemma3n_template = """{{ bos_token }}
 {%- if messages[0]['role'] == 'system' -%}
     {%- if messages[0]['content'] is string -%}
         {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
@@ -1239,8 +1347,7 @@
 """
 
 # Ollama from https://ollama.com/library/gemma3n/blobs/e0a42594d802
-gemma3n_ollama = \
-'''
+gemma3n_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $i, $_ := .Messages }}
 {{- $last := eq (len (slice $.Messages $i)) 1 }}
@@ -1256,18 +1363,26 @@
 '''
 
 gemma3n_template_eos_token = ""
-CHAT_TEMPLATES["gemma-3n"] = (gemma3n_template, gemma3n_template_eos_token, False, gemma3n_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma-3n"] = None # No system message in Gemma-3n
+CHAT_TEMPLATES["gemma-3n"] = (
+    gemma3n_template,
+    gemma3n_template_eos_token,
+    False,
+    gemma3n_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma-3n"] = None  # No system message in Gemma-3n
 
-CHAT_TEMPLATES["gemma3n"] = (gemma3n_template, gemma3n_template_eos_token, False, gemma3n_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma3n"] = None # No system message in Gemma-3n
-pass
+CHAT_TEMPLATES["gemma3n"] = (
+    gemma3n_template,
+    gemma3n_template_eos_token,
+    False,
+    gemma3n_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gemma3n"] = None  # No system message in Gemma-3n
 
 # =========================================== GPT-OSS
 # Obtained via
 # print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
-gptoss_template = \
-"""{#-
+gptoss_template = """{#-
   In addition to the normal inputs of `messages` and `tools`, this template also accepts the
   following kwargs:
   - "builtin_tools": A list, can contain "browser" and/or "python".
@@ -1617,8 +1732,7 @@
 {%- endif -%}"""
 
 # Ollama from https://ollama.com/library/gemma3n/blobs/e0a42594d802
-gptoss_ollama = \
-'''
+gptoss_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
 Knowledge cutoff: 2024-06
@@ -1799,16 +1913,24 @@
 '''
 
 gptoss_template_template_eos_token = "<|return|>"
-CHAT_TEMPLATES["gpt-oss"] = (gptoss_template, gptoss_template_template_eos_token, False, gptoss_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gpt-oss"] = None # No system message in GPT-oss
+CHAT_TEMPLATES["gpt-oss"] = (
+    gptoss_template,
+    gptoss_template_template_eos_token,
+    False,
+    gptoss_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gpt-oss"] = None  # No system message in GPT-oss
 
-CHAT_TEMPLATES["gptoss"] = (gptoss_template, gptoss_template_template_eos_token, False, gptoss_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gptoss"] = None # No system message in GPT-oss
-pass
+CHAT_TEMPLATES["gptoss"] = (
+    gptoss_template,
+    gptoss_template_template_eos_token,
+    False,
+    gptoss_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["gptoss"] = None  # No system message in GPT-oss
 
 # =========================================== Qwen3-Instruct
-qwen3_instruct_template = \
-'''{%- if tools %}
+qwen3_instruct_template = """{%- if tools %}
     {{- '<|im_start|>system\\n' }}
     {%- if messages[0].role == 'system' %}
         {{- messages[0].content + '\\n\\n' }}
@@ -1893,11 +2015,10 @@
 {%- endfor %}
 {%- if add_generation_prompt %}
     {{- '<|im_start|>assistant\\n' }}
-{%- endif %}'''
+{%- endif %}"""
 
 # Ollama from https://ollama.com/library/qwen3/blobs/53e4ea15e8f5
-qwen3_ollama = \
-'''
+qwen3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{- $lastUserIdx := -1 -}}
@@ -1954,14 +2075,17 @@
 '''
 
 qwen3_template_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["qwen3-instruct"] = (qwen3_instruct_template, qwen3_template_eos_token, False, qwen3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen3-instruct"] = None # No system message in Qwen3
+CHAT_TEMPLATES["qwen3-instruct"] = (
+    qwen3_instruct_template,
+    qwen3_template_eos_token,
+    False,
+    qwen3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen3-instruct"] = None  # No system message in Qwen3
 
-pass
 
 # =========================================== Qwen3-Thinking
-qwen3_thinking_template = \
-'''{%- if tools %}
+qwen3_thinking_template = """{%- if tools %}
     {{- '<|im_start|>system\\n' }}
     {%- if messages[0].role == 'system' %}
         {{- messages[0].content + '\\n\\n' }}
@@ -2046,31 +2170,37 @@
 {%- endfor %}
 {%- if add_generation_prompt %}
     {{- '<|im_start|>assistant\n\n' }}
-{%- endif %}'''
+{%- endif %}"""
 
-CHAT_TEMPLATES["qwen3-thinking"] = (qwen3_thinking_template, qwen3_template_eos_token, False, qwen3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen3-thinking"] = None # No system message in Qwen3
+CHAT_TEMPLATES["qwen3-thinking"] = (
+    qwen3_thinking_template,
+    qwen3_template_eos_token,
+    False,
+    qwen3_ollama,
+)
+DEFAULT_SYSTEM_MESSAGE["qwen3-thinking"] = None  # No system message in Qwen3
 
-pass
 
 # =========================================== Liquid-LFM2
-liquid_lfm2_template = \
-'''
+liquid_lfm2_template = """
 {{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '
 ' + message['content'] + '<|im_end|>' + '
 '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
-' }}{% endif %}'''
+' }}{% endif %}"""
 
 liquid_lfm2_template_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["lfm-2"] = (liquid_lfm2_template, liquid_lfm2_template_eos_token, False, None)
-DEFAULT_SYSTEM_MESSAGE["lfm-2"] = None # No system message in Phi-3
+CHAT_TEMPLATES["lfm-2"] = (
+    liquid_lfm2_template,
+    liquid_lfm2_template_eos_token,
+    False,
+    None,
+)
+DEFAULT_SYSTEM_MESSAGE["lfm-2"] = None  # No system message in Phi-3
 
-pass
 
 # =========================================== Starling-LM
 
-starling_template = \
-"""{{ bos_token }}
+starling_template = """{{ bos_token }}
 {%- for message in messages %}
     {{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>' }}
 {%- endfor %}
@@ -2079,8 +2209,7 @@
 {%- endif %}"""
 
 # Ollama from https://ollama.com/library/starling-lm:7b/blobs/4b21bfc435b4
-starling_ollama = \
-'''
+starling_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>
 {{ end }}{{ if .Prompt }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>
@@ -2094,15 +2223,18 @@
 '''
 
 starling_template_eos_token = "<|end_of_turn|>"
-CHAT_TEMPLATES["starling"] = (starling_template, starling_template_eos_token, False, starling_ollama)
+CHAT_TEMPLATES["starling"] = (
+    starling_template,
+    starling_template_eos_token,
+    False,
+    starling_ollama,
+)
 DEFAULT_SYSTEM_MESSAGE["starling"] = None
 
-pass
 
 # =========================================== Yi-chat
 
-yi_chat_template = \
-"""
+yi_chat_template = """
 {% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '
 ' + message['content'] + '<|im_end|>' + '
 '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
@@ -2110,8 +2242,7 @@
 """
 
 # Ollama from https://ollama.com/library/yi:34b-chat/blobs/62fbfd9ed093
-yi_chat_ollama = \
-'''
+yi_chat_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
@@ -2122,11 +2253,18 @@
 '''
 
 yi_chat_template_eos_token = "<|endoftext|>"
-CHAT_TEMPLATES["yi-chat"] = (yi_chat_template, yi_chat_template_eos_token, False, yi_chat_ollama)
+CHAT_TEMPLATES["yi-chat"] = (
+    yi_chat_template,
+    yi_chat_template_eos_token,
+    False,
+    yi_chat_ollama,
+)
 DEFAULT_SYSTEM_MESSAGE["yi-chat"] = None
-pass
 
-def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
+
+def _change_system_message(
+    template: str, type_chat_template: str, system_message: str = None
+):
     system_message_pattern = r"\{system_message\}"
 
     # For predefined templates, check if default system message exists
@@ -2139,7 +2277,6 @@ def _change_system_message(template: str, type_chat_template: str, system_messag
                 "You need to manually add the system message in your data."
             )
         return template, system_message
-    pass
 
     # For custom templates
     if type_chat_template is None:
@@ -2147,36 +2284,43 @@ def _change_system_message(template: str, type_chat_template: str, system_messag
 
         if has_placeholder:
             if system_message is None:
-                raise ValueError("Unsloth: You need to provide a system message for custom templates.")
+                raise ValueError(
+                    "Unsloth: You need to provide a system message for custom templates."
+                )
             new_template = re.sub(system_message_pattern, system_message, template)
             return new_template, system_message
 
         return template, system_message
-    pass
 
     # For predefined templates with default system message
-    message_to_use = system_message if system_message is not None else default_system_message
+    message_to_use = (
+        system_message if system_message is not None else default_system_message
+    )
     new_template = re.sub(system_message_pattern, message_to_use, template)
 
     return new_template, message_to_use
-pass
 
 
 def get_chat_template(
     tokenizer,
     chat_template = "chatml",
-    mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
+    mapping = {
+        "role": "role",
+        "content": "content",
+        "user": "user",
+        "assistant": "assistant",
+    },
     map_eos_token = True,
     system_message = None,
 ):
-    assert(type(map_eos_token) is bool)
+    assert type(map_eos_token) is bool
     old_tokenizer = tokenizer
 
     IS_GEMMA = False
     if tokenizer.__class__.__name__.startswith("Gemma"):
-        if chat_template == "chatml": chat_template = "gemma_chatml"
+        if chat_template == "chatml":
+            chat_template = "gemma_chatml"
         IS_GEMMA = True
-    pass
 
     # We add a check for Llama-3
     # if chat_template == "llama-3":
@@ -2196,32 +2340,42 @@ def get_chat_template(
     same_padding_token = False
     type_chat_template = None
 
-    if type(chat_template) in (list, tuple,):
+    if type(chat_template) in (
+        list,
+        tuple,
+    ):
         # For changing system message later
         # Since it's not supported yet, we will raise an error first!
         type_chat_template = chat_template[0].lower()
         chat_template, stop_word = chat_template
-        assert(type(chat_template) is str)
-        assert(type(stop_word) is str)
+        assert type(chat_template) is str
+        assert type(stop_word) is str
         ollama_modelfile = None
 
     elif type(chat_template) is str:
         # For changing system message later
         type_chat_template = chat_template.lower()
 
-        chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
+        chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[
+            chat_template
+        ]
 
         # Check mapping to eos_token
-        if not map_eos_token and yes_map_eos_token: map_eos_token = True
-        if not yes_map_eos_token and map_eos_token: map_eos_token = False
-
-        if type(stop_word) in (list, tuple,):
+        if not map_eos_token and yes_map_eos_token:
+            map_eos_token = True
+        if not yes_map_eos_token and map_eos_token:
+            map_eos_token = False
+
+        if type(stop_word) in (
+            list,
+            tuple,
+        ):
             token_mapping, stop_word = stop_word
-            assert(type(token_mapping) is dict)
+            assert type(token_mapping) is dict
         else:
             token_mapping = None
 
-        assert(type(stop_word) is str)
+        assert type(stop_word) is str
 
         # Check fast tokenizer
         if not is_fast_tokenizer:
@@ -2248,15 +2402,16 @@ def get_chat_template(
                 elif old_count == 0:
                     raise RuntimeError(f"{old_token} was not part of the tokenizer!")
                 else:
-                    string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
-                pass
-            pass
+                    string_vocab = string_vocab.replace(
+                        f'"{old_token}"', f'"{new_token}"'
+                    )
 
             if map_eos_token and (not stop_word in token_mapping.values()):
                 # Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
-                logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
+                logger.warning_once(
+                    f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}."
+                )
                 string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
-            pass
 
             if skipped != len(token_mapping):
                 new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
@@ -2266,7 +2421,6 @@ def get_chat_template(
                 if old_pad_token == tokenizer.eos_token:
                     old_pad_token = stop_word
                     same_padding_token = True
-                pass
 
                 if map_eos_token:
                     new_tokenizer = tokenizer.__class__(
@@ -2279,15 +2433,18 @@ def get_chat_template(
                         tokenizer_object = new_tokenizer,
                         pad_token = old_pad_token,
                     )
-                pass
 
                 # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
-                tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
-            else:
-                pass
+                tokenizer = fix_sentencepiece_tokenizer(
+                    tokenizer,
+                    new_tokenizer,
+                    token_mapping,
+                )
 
         elif map_eos_token and (stop_word != "eos_token"):
-            logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
+            logger.warning_once(
+                f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}."
+            )
 
             # Replaces the old EOS token with a new one.
             # Useful for ChatML <|im_end|> for example.
@@ -2311,14 +2468,12 @@ def get_chat_template(
                 string_vocab = string_vocab.replace(temporary_stop_token, stop_word)
             else:
                 string_vocab = string_vocab.replace(old_eos_token, stop_word)
-            pass
             new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
 
             # Careful on pad_token
             if old_pad_token == old_eos_token:
                 old_pad_token = stop_word
                 same_padding_token = True
-            pass
 
             new_tokenizer = tokenizer.__class__(
                 tokenizer_object = new_tokenizer,
@@ -2329,46 +2484,59 @@ def get_chat_template(
             )
 
             # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
-            token_mapping = { old_eos_token : stop_word, }
-            tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
-        pass
+            token_mapping = {
+                old_eos_token: stop_word,
+            }
+            tokenizer = fix_sentencepiece_tokenizer(
+                tokenizer,
+                new_tokenizer,
+                token_mapping,
+            )
 
     else:
         raise TypeError(
-            f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
+            f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"
             f"{CHAT_TEMPLATES.keys()}"
         )
-    pass
 
     # Careful on Gemma
     # bos_token is a must or else losses become too high
-    if IS_GEMMA and not chat_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
+    if IS_GEMMA and not chat_template.startswith(
+        ("{{ bos_token }}", "{{- bos_token }}")
+    ):
         chat_template = "{{ bos_token }}" + chat_template
-    pass
 
     # For ShareGPT role -> from and content -> value
-    new_chat_template = chat_template\
-        .replace("'role'",      "'" + mapping["role"]      + "'")\
-        .replace("'content'",   "'" + mapping["content"]   + "'")\
-        .replace("'user'",      "'" + mapping["user"]      + "'")\
+    new_chat_template = (
+        chat_template.replace("'role'", "'" + mapping["role"] + "'")
+        .replace("'content'", "'" + mapping["content"] + "'")
+        .replace("'user'", "'" + mapping["user"] + "'")
         .replace("'assistant'", "'" + mapping["assistant"] + "'")
+    )
 
     _, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
     tokenizer.padding_side = old_padding_side
 
     # If not normal HF, we add a check to make old templates work
-    if mapping != {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"}:
-        chat_template = \
-            "{% if 'role' in messages[0] %}" + \
-            chat_template + \
-            "{% else %}" + \
-            new_chat_template + \
-            "{% endif %}"
+    if mapping != {
+        "role": "role",
+        "content": "content",
+        "user": "user",
+        "assistant": "assistant",
+    }:
+        chat_template = (
+            "{% if 'role' in messages[0] %}"
+            + chat_template
+            + "{% else %}"
+            + new_chat_template
+            + "{% endif %}"
+        )
     else:
         chat_template = new_chat_template
-    pass
 
-    chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)
+    chat_template, system_message = _change_system_message(
+        chat_template, type_chat_template, system_message
+    )
 
     tokenizer.chat_template = chat_template
 
@@ -2376,14 +2544,16 @@ def get_chat_template(
     old_pad_token = getattr(old_tokenizer, "pad_token", None)
     old_bos_token = getattr(old_tokenizer, "bos_token", None)
     old_unk_token = getattr(old_tokenizer, "unk_token", None)
-    new_pad_token = getattr(tokenizer,     "pad_token", None)
-    new_bos_token = getattr(tokenizer,     "bos_token", None)
-    new_unk_token = getattr(tokenizer,     "unk_token", None)
-    if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
-    if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
+    new_pad_token = getattr(tokenizer, "pad_token", None)
+    new_bos_token = getattr(tokenizer, "bos_token", None)
+    new_unk_token = getattr(tokenizer, "unk_token", None)
+    if old_bos_token != new_bos_token:
+        tokenizer.bos_token = old_bos_token
+    if old_unk_token != new_unk_token:
+        tokenizer.unk_token = old_unk_token
     if not same_padding_token:
-        if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
-    pass
+        if old_pad_token != new_pad_token:
+            tokenizer.pad_token = old_pad_token
 
     # stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
 
@@ -2392,18 +2562,15 @@ def get_chat_template(
 
     # Add Ollama
     tokenizer._ollama_modelfile = ollama_modelfile
-    tokenizer._system_message   = system_message
-    return tokenizer#, stopping_criteria
-pass
+    tokenizer._system_message = system_message
+    return tokenizer  # , stopping_criteria
 
 
 def remove_special_tokens(tokenizer, prompt):
     # Removes double BOS token
     if prompt.startswith(tokenizer.bos_token):
-        prompt = prompt[len(tokenizer.bos_token):]
-    pass
+        prompt = prompt[len(tokenizer.bos_token) :]
     return prompt
-pass
 
 
 def _parse_combined_prompt(combined_prompt, dataset):
@@ -2413,14 +2580,14 @@ def _parse_combined_prompt(combined_prompt, dataset):
     for column in possible_columns:
         if column not in dataset_columns:
             raise KeyError(
-                f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
+                f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "
                 f"Only allowed columns are {list(dataset_columns)}"
             )
-        pass
-    pass
 
     # Find [[...]]
-    optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
+    optional_prompts = list(
+        re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE)
+    )
     optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]
 
     final_optional_prompts = []
@@ -2428,30 +2595,32 @@ def _parse_combined_prompt(combined_prompt, dataset):
         # Add left
         left = optional_prompts[0]
         l = left[0][0]
-        if l != 0: final_optional_prompts.append(combined_prompt[:l])
+        if l != 0:
+            final_optional_prompts.append(combined_prompt[:l])
 
         # Add in between
         for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):
             l, r = left[0][-1], right[0][0]
             final_optional_prompts.append(left)
-            if l != r: final_optional_prompts.append(combined_prompt[l : r])
-        pass
+            if l != r:
+                final_optional_prompts.append(combined_prompt[l:r])
         final_optional_prompts.append(optional_prompts[-1])
 
         # Add right
         right = optional_prompts[-1]
         r = right[0][1]
-        if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])
+        if r != len(combined_prompt):
+            final_optional_prompts.append(combined_prompt[r:])
     else:
         # Just add in the entire string
         final_optional_prompts.append(combined_prompt)
-    pass
 
-    check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
-    assert(combined_prompt == check_combined)
+    check_combined = "".join(
+        x if type(x) is str else x[1] for x in final_optional_prompts
+    )
+    assert combined_prompt == check_combined
 
     return possible_columns, final_optional_prompts
-pass
 
 
 def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
@@ -2459,9 +2628,11 @@ def _create_formatter(possible_columns, final_optional_prompts, user_column_name
     function = ["def __combined_prompt_processor__(examples):"]
     columns = list(set(possible_columns))
     for column in columns:
-        function.append(f"{' '*4}{column}__ = examples['{column}']")
-    function.append(f"{' '*4}texts = []")
-    function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
+        function.append(f"{' ' * 4}{column}__ = examples['{column}']")
+    function.append(f"{' ' * 4}texts = []")
+    function.append(
+        f"{' ' * 4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):"
+    )
 
     # Add optional tags as well!
     final_prompt = ""
@@ -2472,27 +2643,37 @@ def _create_formatter(possible_columns, final_optional_prompts, user_column_name
             columns = re.findall(r"\{(.+?)\}", optional_prompt)
             formatter += columns
             # Must escape \n \r
-            final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
+            final_prompt += (
+                optional_prompt.encode("unicode-escape")
+                .decode("utf-8")
+                .replace("'", "\\'")
+                .replace('"', '\\"')
+            )
         else:
             where, prompt = optional_prompt
             # Strip [[...]]
             # Must escape \n \r
-            prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
+            prompt = (
+                prompt[2:-2]
+                .encode("unicode-escape")
+                .decode("utf-8")
+                .replace("'", "\\'")
+                .replace('"', '\\"')
+            )
             columns = re.findall(r"\{(.+?)\}", prompt)
             x = f"__optional_{j}__"
-            prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
+            prompt = f"{' ' * 8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
             function.append(prompt)
             formatter.append(x)
             final_prompt += "{" + x + "}"
-        pass
-    pass
 
-    function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
-    function.append(f"{' '*8}texts.append("\
-                    f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
-    function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
+    function.insert(1, f"{' ' * 4}__combined_prompt__ = '{final_prompt}'")
+    function.append(
+        f"{' ' * 8}texts.append("
+        f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))"
+    )
+    function.append(f"{' ' * 4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
     return "\n".join(function)
-pass
 
 
 def to_sharegpt(
@@ -2521,27 +2702,34 @@ def to_sharegpt(
     if "conversations" in dataset.column_names:
         convo = dataset[0]["conversations"]
         if type(convo) is list:
-            raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
-        pass
-    pass
+            raise TypeError(
+                "Unsloth: Your dataset is probably already in ShareGPT format!"
+            )
 
-    possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
-    function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
+    possible_columns, final_optional_prompts = _parse_combined_prompt(
+        merged_prompt, dataset
+    )
+    function = _create_formatter(
+        possible_columns, final_optional_prompts, merged_column_name
+    )
     exec(function, globals())
-    dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
+    dataset = dataset.map(
+        __combined_prompt_processor__, batched = True, desc = "Merging columns"
+    )
 
     def __convert_to_sharegpt__(examples):
-        users      = examples[merged_column_name]
+        users = examples[merged_column_name]
         assistants = examples[output_column_name]
         texts = [
             [
-                {"from" : "human", "value" : str(user)     },
-                {"from" : "gpt",   "value" : str(assistant)},
-            ] \
+                {"from": "human", "value": str(user)},
+                {"from": "gpt", "value": str(assistant)},
+            ]
             for user, assistant in zip(users, assistants)
         ]
-        return { "conversations" : texts, }
-    pass
+        return {
+            "conversations": texts,
+        }
 
     dataset = dataset.map(
         __convert_to_sharegpt__,
@@ -2553,28 +2741,35 @@ def __convert_to_sharegpt__(examples):
 
     # Randomnly concat conversations to create a long stream!
     from datasets import concatenate_datasets
-    n_extensions = max(conversation_extension-1, 0)
-    if n_extensions == 0: return dataset
 
-    dataset = dataset.rename_columns({"conversations" : "conversations0"})
+    n_extensions = max(conversation_extension - 1, 0)
+    if n_extensions == 0:
+        return dataset
+
+    dataset = dataset.rename_columns({"conversations": "conversations0"})
     all_shuffled = [dataset]
-    for j in range(1, n_extensions+1):
-        shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
+    for j in range(1, n_extensions + 1):
+        shuffled = dataset.shuffle(seed = random_state + j).rename_columns(
+            {"conversations0": f"conversations{j}"}
+        )
         all_shuffled.append(shuffled)
-    pass
     dataset = concatenate_datasets(all_shuffled, axis = 1)
 
     # Combine them into 1
     function = "def __combine_conversations__(examples):\n"
     n_extensions += 1
     for j in range(n_extensions):
-        function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
-    function += f"{' '*4}convos = []\n"
-    function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
-                f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
-    function += f"{' '*8}convos.append("\
-                f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
-    function += f"{' '*4}return " + "{ " + "'conversations' : convos" + " }"
+        function += f"{' ' * 4}conversations{j}__ = examples['conversations{j}']\n"
+    function += f"{' ' * 4}convos = []\n"
+    function += (
+        f"{' ' * 4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "
+        f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
+    )
+    function += (
+        f"{' ' * 8}convos.append("
+        f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
+    )
+    function += f"{' ' * 4}return " + "{ " + "'conversations' : convos" + " }"
 
     # Map function
     exec(function, globals())
@@ -2586,7 +2781,6 @@ def __convert_to_sharegpt__(examples):
         remove_columns = dataset.column_names if remove_unused_columns else None,
     )
     return dataset
-pass
 
 
 def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
@@ -2598,53 +2792,53 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
 
     # Remove BOS
     if getattr(tokenizer, "bos_token", None) is not None:
-        added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
-    pass
+        added_tokens_decoder = [
+            x for x in added_tokens_decoder if x != tokenizer.bos_token
+        ]
 
     repeatted_tokens = []
     # Join all vocab
     joined_text = "\x01\x00".join(added_tokens_decoder)
     for token in added_tokens_decoder:
         n = len(token)
-        repeatted_counts = joined_text.count(token[:n//2])
+        repeatted_counts = joined_text.count(token[: n // 2])
         # Try finding longer than 1/2 of the token in the rest
         # For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>
         if repeatted_counts > 2:
-            for j in range(n//2+1, n):
+            for j in range(n // 2 + 1, n):
                 if joined_text.count(token[:j]) < repeatted_counts:
                     j -= 1
                     # Remove repeatted tokens to reduce search space
                     joined_text = joined_text.replace(token[:j], "")
                     repeatted_tokens.append(token[:j])
                     break
-            pass
-        pass
-    pass
 
     # Remove duplicates
     splitted = joined_text.split("\x01\x00")
-    final_eos_tokens = [old for old, new in zip(added_tokens_decoder, splitted) if old == new]
+    final_eos_tokens = [
+        old for old, new in zip(added_tokens_decoder, splitted) if old == new
+    ]
     final_eos_tokens += extra_eos_tokens
     final_eos_tokens += repeatted_tokens
 
     # Remove new lines, spaces and HTML tags
     filtered_eos_tokens = []
     for token in final_eos_tokens:
-        if   token.count("\n") == len(token): continue
-        elif token.count("ā") == len(token): continue
-        elif token.startswith("<") and len(token) <= 2: continue
-        elif token.startswith("") and len(token) == 3: continue
+        if token.count("\n") == len(token):
+            continue
+        elif token.count("ā") == len(token):
+            continue
+        elif token.startswith("<") and len(token) <= 2:
+            continue
+        elif token.startswith("") and len(token) == 3:
+            continue
         filtered_eos_tokens.append(token)
-    pass
     return filtered_eos_tokens
-pass
-
 
-def construct_chat_template( \
 
-tokenizer = None,
-
-chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+def construct_chat_template(
+    tokenizer = None,
+    chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
@@ -2655,11 +2849,8 @@ def construct_chat_template( \
 {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 {OUTPUT}<|eot_id|>""",
-
-default_system_message = \
-    "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
-
-extra_eos_tokens = None,
+    default_system_message = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
+    extra_eos_tokens = None,
 ):
     """
     Creates a Ollama modelfile and a HF Jinja template from a custom
@@ -2671,27 +2862,32 @@ def construct_chat_template( \
     # Strip only the left
     chat_template = chat_template.lstrip()
 
-    assert(tokenizer is not None)
+    assert tokenizer is not None
 
-    if extra_eos_tokens is None: extra_eos_tokens = []
-    elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]
+    if extra_eos_tokens is None:
+        extra_eos_tokens = []
+    elif type(extra_eos_tokens) is str:
+        extra_eos_tokens = [
+            extra_eos_tokens,
+        ]
 
     vocab = tokenizer.get_vocab()
     for extra_eos in extra_eos_tokens:
-        assert(type(extra_eos) is str)
+        assert type(extra_eos) is str
         if extra_eos not in vocab:
-            raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
-        pass
-    pass
-
-    error_msg = \
-        "Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
-        "and the assistant output {OUTPUT}\n\n"\
-        "For example what is not allowed is just:\n"\
-        "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
-        "What is required is 2x of this:\n"\
-        "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
+            raise ValueError(
+                f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer."
+            )
+
+    error_msg = (
+        "Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "
+        "and the assistant output {OUTPUT}\n\n"
+        "For example what is not allowed is just:\n"
+        "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"
+        "What is required is 2x of this:\n"
         "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
+        "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
+    )
 
     # Check for EOS after {OUTPUT}
     if tokenizer.eos_token is not None:
@@ -2700,71 +2896,78 @@ def construct_chat_template( \
         raise RuntimeError(
             "Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
         )
-    pass
 
     # Check tokenizer types
     tokenizer_name = tokenizer.name_or_path.lower()
-    if tokenizer_name.startswith(("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")):
+    if tokenizer_name.startswith(
+        ("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")
+    ):
         # Add <|eot_id|>
         extra_eos_tokens.append("<|eot_id|>")
-    elif ("<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template) and \
-        tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
+    elif (
+        "<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template
+    ) and tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
         # Warn
         logger.warning(
-            "Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
+            "Unsloth: Base llama-3 models did not train <|eot_id|>.\n"
             "Please use the instruct version or use <|end_of_text|>"
         )
-    pass
     extra_eos_tokens = list(set(extra_eos_tokens))
 
     count_eos = 0
     for eos in extra_eos_tokens:
         count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
-    pass
 
     # This forces you to provide 2 input and outputs
     final_combined_check = False
 
     try:
         # O(N^2) search finding 2 repeatted pieces of text
-        j = len(chat_template)-1
+        j = len(chat_template) - 1
         at_least_one = False
         while j > 0:
             found = chat_template.rfind(chat_template[j:], 0, j)
-            if found == -1: break
+            if found == -1:
+                break
             j -= 1
             at_least_one = True
-        pass
-        if j > 0: j += 1
-        else: raise RuntimeError(error_msg)
+        if j > 0:
+            j += 1
+        else:
+            raise RuntimeError(error_msg)
 
-        if not at_least_one: raise RuntimeError(error_msg)
+        if not at_least_one:
+            raise RuntimeError(error_msg)
 
         # Must be equivalent to left
         final_combined_check = True
 
         # Repeatted text
         instruction_response = chat_template[j:]
-        if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
+        if (
+            instruction_response.count("{INPUT}") != 1
+            or instruction_response.count("{OUTPUT}") != 1
+        ):
             raise RuntimeError(error_msg)
-        pass
 
         # 1st System, Instruction, Output pair
-        left  = chat_template[:j]
+        left = chat_template[:j]
         # 2nd Instruction, Output pair
         right = chat_template[j:]
 
         final_combined_check = left if final_combined_check else chat_template
 
         # Isolate input
-        extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
+        extra_eos_tokens_regex = "|".join(
+            f"(?:{re.escape(x)})" for x in extra_eos_tokens
+        )
         if len(extra_eos_tokens_regex) != 0:
             find_end = f"(?:{extra_eos_tokens_regex})?"
         else:
             find_end = ""
         find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
         input_end = list(re.finditer(find_end, right))
-        assert(len(input_end) == 1)
+        assert len(input_end) == 1
         input_end = input_end[0]
         input_end = input_end.span(0)[1]
         input_part = right[:input_end]
@@ -2774,52 +2977,65 @@ def construct_chat_template( \
 
         # Isolate system
         where_system = left.find(input_part)
-        system_part = left[:where_system if where_system != -1 else len(left)]
+        system_part = left[: where_system if where_system != -1 else len(left)]
 
         # Check if the user provided a correct prompt
         combined = system_part + input_part + output_part
         if combined != final_combined_check:
-            combined_changed = combined            .replace('\n', '\\n')
-            left_changed     = final_combined_check.replace('\n', '\\n')
+            combined_changed = combined.replace("\n", "\\n")
+            left_changed = final_combined_check.replace("\n", "\\n")
             raise RuntimeError(
-                "Unsloth: The prompt template you provided isn't correct. You gave:\n"\
-                f"{combined_changed}\n\n"\
-                "But we require the following:\n"\
+                "Unsloth: The prompt template you provided isn't correct. You gave:\n"
+                f"{combined_changed}\n\n"
+                "But we require the following:\n"
                 f"{left_changed}"
             )
-        pass
     except:
-        ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
+        ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}") :]
 
         ending = re.escape(ending)
         find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
-        response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
+        response_part = re.findall(
+            find_text, chat_template, flags = re.DOTALL | re.MULTILINE
+        )
         response_part = response_part[0]
 
         for j in range(1, len(response_part)):
             try_find = re.escape(response_part[:j])
-            try: found = next(re.finditer("(" + try_find + ").+?\\{INPUT\\}", chat_template, flags = re.DOTALL | re.MULTILINE))
-            except: break
-        pass
+            try:
+                found = next(
+                    re.finditer(
+                        "(" + try_find + ").+?\\{INPUT\\}",
+                        chat_template,
+                        flags = re.DOTALL | re.MULTILINE,
+                    )
+                )
+            except:
+                break
         separator = found.group(1)
 
         response_start = chat_template.find(response_part)
         start_instruction = chat_template[:response_start].rfind(separator)
-        if start_instruction == -1: start_instruction = 0
+        if start_instruction == -1:
+            start_instruction = 0
         instruction_part = chat_template[start_instruction:response_start]
 
         combined = instruction_part + response_part
         where = chat_template.find(combined)
         system_part = chat_template[:where]
 
-        system_part, input_part, output_part = system_part, instruction_part, response_part
-    pass
+        system_part, input_part, output_part = (
+            system_part,
+            instruction_part,
+            response_part,
+        )
 
     if count_eos == 0:
-        logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
+        logger.warning(
+            "Unsloth: We automatically added an EOS token to stop endless generations."
+        )
         eos = extra_eos_tokens[0]
         output_part = output_part + eos
-    pass
 
     # Ollama modelfile parts
 
@@ -2831,60 +3047,73 @@ def construct_chat_template( \
         always_bos_token = True
         if ollama_system.startswith(tokenizer.bos_token):
             has_bos_token = True
-            ollama_system = ollama_system[len(tokenizer.bos_token):]
-        pass
-    pass
+            ollama_system = ollama_system[len(tokenizer.bos_token) :]
     # Check system
     if "{SYSTEM}" in ollama_system:
-        system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
+        system_modelfile = (
+            "{{ if .System }}"
+            + ollama_system.replace("{SYSTEM}", "{{ .System }}")
+            + "{{ end }}"
+        )
     else:
         system_modelfile = ollama_system
-    pass
-    input_modelfile  = "{{ if .Prompt }}" + input_part .replace("{INPUT}",  "{{ .Prompt }}") + "{{ end }}"
+    input_modelfile = (
+        "{{ if .Prompt }}"
+        + input_part.replace("{INPUT}", "{{ .Prompt }}")
+        + "{{ end }}"
+    )
     output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
 
     # Ollama EOS
     ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)
-    ollama_eos = '\n'.join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
+    ollama_eos = "\n".join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
 
     # Add temperature and min_p to counteract gibberish
     ollama_eos += "\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1"
 
     # Ollama modelfile
     part = '"""'
-    modelfile = 'FROM {__FILE_LOCATION__}\n\n'\
-    'TEMPLATE ' + part + system_modelfile + input_modelfile + output_modelfile + \
-        part + '\n\n' + ollama_eos
+    modelfile = (
+        "FROM {__FILE_LOCATION__}\n\n"
+        "TEMPLATE "
+        + part
+        + system_modelfile
+        + input_modelfile
+        + output_modelfile
+        + part
+        + "\n\n"
+        + ollama_eos
+    )
 
     # HF Jinja Chat template
     def process(part, which, content = "message['content']"):
         if part.endswith(which):
-            part = "'" + part[:part.find(which)] + f"' + {content}"
+            part = "'" + part[: part.find(which)] + f"' + {content}"
         elif part.startswith(which):
-            part = f"{content} + '" + part[part.find(which):] + "'"
+            part = f"{content} + '" + part[part.find(which) :] + "'"
         else:
             part = "'" + part.replace(which, f"' + {content} + '") + "'"
-        if part.startswith("'' + "): part = part[5:]
+        if part.startswith("'' + "):
+            part = part[5:]
         return part
-    pass
-    input_jinja  = process(input_part,  "{INPUT}")
+
+    input_jinja = process(input_part, "{INPUT}")
     output_jinja = process(output_part, "{OUTPUT}")
-    pass
-
-    jinja_template = \
-        "{% for message in loop_messages %}"\
-            "{% if message['role'] == 'user' %}"\
-                "{{ " + input_jinja + " }}"\
-            "{% elif message['role'] == 'assistant' %}"\
-                "{{ " + output_jinja + " }}"\
-            "{% else %}"\
-                "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-            "{% endif %}"\
-        "{% endfor %}"\
-        "{% if add_generation_prompt %}"\
-            "{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
+
+    jinja_template = (
+        "{% for message in loop_messages %}"
+        "{% if message['role'] == 'user' %}"
+        "{{ " + input_jinja + " }}"
+        "{% elif message['role'] == 'assistant' %}"
+        "{{ " + output_jinja + " }}"
+        "{% else %}"
+        "{{ raise_exception('Only user and assistant roles are supported!') }}"
         "{% endif %}"
-    pass
+        "{% endfor %}"
+        "{% if add_generation_prompt %}"
+        "{{ '" + output_part[: output_part.find("{OUTPUT}")] + "' }}"
+        "{% endif %}"
+    )
 
     # Now add system prompt to jinja
     if len(system_part) != 0:
@@ -2894,73 +3123,72 @@ def process(part, which, content = "message['content']"):
         if "{SYSTEM}" in partial_system:
             if default_system_message is None:
                 raise RuntimeError("Unsloth: Please specify a default system message!")
-        pass
 
         # Separate the BOS
         if has_bos_token:
             partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
-            system_part    = system_part   .replace(tokenizer.bos_token, "", 1)
-        pass
+            system_part = system_part.replace(tokenizer.bos_token, "", 1)
 
-        partial_system = \
-            "{% if messages[0]['role'] == 'system' %}"\
-                "{{ " + partial_system + " }}"\
-                "{% set loop_messages = messages[1:] %}"
+        partial_system = (
+            "{% if messages[0]['role'] == 'system' %}"
+            "{{ " + partial_system + " }}"
+            "{% set loop_messages = messages[1:] %}"
+        )
         if default_system_message is not None:
             full_system = system_part.replace("{SYSTEM}", default_system_message)
             if "{SYSTEM}" in system_part:
                 modelfile += '\nSYSTEM "' + default_system_message + '"'
-            pass
-            partial_system += "{% else %}"\
-                "{{ '" + full_system + "' }}"\
-                "{% set loop_messages = messages %}"\
-            "{% endif %}"
+            partial_system += (
+                "{% else %}"
+                "{{ '" + full_system + "' }}"
+                "{% set loop_messages = messages %}"
+                "{% endif %}"
+            )
         else:
             partial_system += "{% endif %}"
-        pass
 
         jinja_template = partial_system + jinja_template
 
         if has_bos_token:
             jinja_template = "{{ bos_token }}" + jinja_template
-    pass
 
     # Fix missing loop_messages
     if "{% set loop_messages = messages %}" not in jinja_template:
         jinja_template = jinja_template.replace(
             "{% for message in loop_messages %}",
             "{% for message in messages %}",
-            1, # Only replace the first one
+            1,  # Only replace the first one
         )
-    pass
 
     # Check if system part is the same!
     jinja_template = re.sub(
-        r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"\
-        r"\{\% set loop\_messages \= messages\[1\:\] \%\}"\
-        r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"\
+        r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"
+        r"\{\% set loop\_messages \= messages\[1\:\] \%\}"
+        r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"
         r"\{\% for message in loop\_messages \%\}",
         r"{{ '\1' }}{% for message in messages %}",
-        jinja_template, flags = re.MULTILINE | re.DOTALL,
+        jinja_template,
+        flags = re.MULTILINE | re.DOTALL,
     )
 
     # Check jinja template for bos
     if always_bos_token:
         if not jinja_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
             jinja_template = "{{ bos_token }}" + jinja_template
-    pass
 
     # Get instruction and output parts for train_on_inputs = False
-    input_part  = input_part [:input_part .find("{INPUT}")]
-    output_part = output_part[:output_part.find("{OUTPUT}")]
+    input_part = input_part[: input_part.find("{INPUT}")]
+    output_part = output_part[: output_part.find("{OUTPUT}")]
     return modelfile, jinja_template, input_part, output_part
-pass
 
 
 def test_construct_chat_template():
     token = "hf_"
     from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token = token)
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        "meta-llama/Meta-Llama-3-8B-Instruct", token = token
+    )
 
     chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
@@ -2974,8 +3202,7 @@ def test_construct_chat_template():
 
 {OUTPUT}<|eot_id|>"""
 
-    default_system_message = \
-        "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
+    default_system_message = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
 
     extra_eos_tokens = None
 
@@ -2993,21 +3220,21 @@ def test_construct_chat_template():
         {"role": "assistant", "content": "Anything else?"},
         {"role": "user", "content": "What's 2x2?"},
     ]
-    correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+    correct_output = tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
 
     tokenizer.chat_template = jinja_template
-    new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
-    assert(correct_output == new_output)
-    pass
-pass
-
-
-def apply_chat_template( \
+    new_output = tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
+    assert correct_output == new_output
 
-dataset,
-tokenizer = None,
 
-chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+def apply_chat_template(
+    dataset,
+    tokenizer = None,
+    chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
@@ -3018,12 +3245,8 @@ def apply_chat_template( \
 {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 {OUTPUT}<|eot_id|>""",
-
-default_system_message = \
-    "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
-
-extra_eos_tokens = None,
-
+    default_system_message = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
+    extra_eos_tokens = None,
 ):
     """
     Creates a Ollama modelfile and a HF Jinja template from a custom
@@ -3038,29 +3261,42 @@ def apply_chat_template( \
         default_system_message = default_system_message,
         extra_eos_tokens = extra_eos_tokens,
     )
+
     def formatting_prompts_func(examples):
         convos = examples["conversations"]
-        texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
-        return { "text" : texts, }
-    pass
+        texts = [
+            tokenizer.apply_chat_template(
+                convo, tokenize = False, add_generation_prompt = False
+            )
+            for convo in convos
+        ]
+        return {
+            "text": texts,
+        }
 
     tokenizer.chat_template = jinja_template
     tokenizer._ollama_modelfile = modelfile
-    tokenizer._unsloth_input_part  = input_part
+    tokenizer._unsloth_input_part = input_part
     tokenizer._unsloth_output_part = output_part
     if hasattr(tokenizer, "tokenizer"):
         tokenizer.tokenizer.chat_template = jinja_template
         tokenizer.tokenizer._ollama_modelfile = modelfile
-        tokenizer.tokenizer._unsloth_input_part  = input_part
+        tokenizer.tokenizer._unsloth_input_part = input_part
         tokenizer.tokenizer._unsloth_output_part = output_part
 
-    return dataset.map(formatting_prompts_func, batched = True,)
-pass
+    return dataset.map(
+        formatting_prompts_func,
+        batched = True,
+    )
 
 
 def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
     class StoppingCriteriaSub(StoppingCriteria):
-        __slots__ = "stop_token", "single_match", "length",
+        __slots__ = (
+            "stop_token",
+            "single_match",
+            "length",
+        )
 
         def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
             super().__init__()
@@ -3068,31 +3304,36 @@ def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
                 self.stop_token = torch.tensor(tokenizer.eos_token_id, device = "cuda")
                 self.length = 1
             else:
-                self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
+                self.stop_token = tokenizer(
+                    ["\n" + stops], add_special_tokens = False, return_tensors = "pt"
+                )
                 self.stop_token = self.stop_token.input_ids.ravel()[1:].to("cuda")
                 self.length = self.stop_token.shape[0]
-            pass
             self.single_match = self.length == 1
-        pass
 
         def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
             input_ids = input_ids.ravel()
             last_token = input_ids[-1]
-            if self.single_match and (last_token == self.stop_token): return True
-
-            if input_ids.shape[0] >= self.length and \
-                (input_ids[-self.length:] == self.stop_token).all(): return True
+            if self.single_match and (last_token == self.stop_token):
+                return True
+
+            if (
+                input_ids.shape[0] >= self.length
+                and (input_ids[-self.length :] == self.stop_token).all()
+            ):
+                return True
             return False
-        pass
-    pass
+
     stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
     return stopping_criteria
-pass
 
 
 def test_chat_templates():
     messages = [
-        {"role": "system","content": " You are a friendly chatbot.",},
+        {
+            "role": "system",
+            "content": " You are a friendly chatbot.",
+        },
         {"role": "user", "content": "What is 2+2?"},
         {"role": "assistant", "content": "It's 4."},
         {"role": "user", "content": "  But 2+2 is equal to 5. "},
@@ -3102,36 +3343,57 @@ def test_chat_templates():
 
     # Zephyr
     from transformers import AutoTokenizer
+
     template = zephyr_template
     correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+    correct_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
-    assert(correct_prompt == our_prompt)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
+    assert correct_prompt == our_prompt
 
     # Chatml
     template = chatml_template
-    correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
-    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+    correct_tokenizer = AutoTokenizer.from_pretrained(
+        "teknium/OpenHermes-2.5-Mistral-7B"
+    )
+    correct_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
-    assert(correct_prompt == our_prompt)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
+    assert correct_prompt == our_prompt
 
     # Mistral
     template = mistral_template
-    correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
-    correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
+    correct_tokenizer = AutoTokenizer.from_pretrained(
+        "mistralai/Mistral-7B-Instruct-v0.2"
+    )
+    correct_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
-    assert(correct_prompt == our_prompt)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
+    assert correct_prompt == our_prompt
 
     # Llama
     template = llama_template
     correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
-    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+    correct_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
-    assert(correct_prompt == our_prompt)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
+    assert correct_prompt == our_prompt
 
     # Vicuna
     try:
@@ -3140,16 +3402,20 @@ def test_chat_templates():
         os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
         from fastchat.conversation import get_conv_template
     correct_prompt = get_conv_template("vicuna_v1.1")
-    for j in range(len(messages)-1):
-        correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
+    for j in range(len(messages) - 1):
+        correct_prompt.append_message(
+            correct_prompt.roles[j % 2 == 1], messages[j + 1]["content"]
+        )
     correct_prompt.append_message(correct_prompt.roles[1], "")
     correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
 
     template = vicuna_template
     correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
-    assert(correct_prompt == our_prompt)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
+    assert correct_prompt == our_prompt
 
     try:
         from fastchat.conversation import get_conv_template
@@ -3157,50 +3423,68 @@ def test_chat_templates():
         os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
         from fastchat.conversation import get_conv_template
     correct_prompt = get_conv_template("zero_shot")
-    for j in range(len(messages)-1):
-        correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
+    for j in range(len(messages) - 1):
+        correct_prompt.append_message(
+            correct_prompt.roles[j % 2 == 1], messages[j + 1]["content"]
+        )
     correct_prompt.append_message(correct_prompt.roles[1], "")
     correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
 
     template = vicuna_old_template
     correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
     # We add  ourselves
-    assert(correct_prompt == our_prompt.replace("", ""))
+    assert correct_prompt == our_prompt.replace("", "")
 
     # Gemma
     correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
-    correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
+    correct_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
     correct_tokenizer.chat_template = gemma_template
-    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
-    assert(our_prompt == correct_prompt)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
+    assert our_prompt == correct_prompt
 
     # Llama-3
     template = llama3_template
     correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
-    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+    correct_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
-    assert(correct_prompt == our_prompt)
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages, tokenize = False, add_generation_prompt = True
+    )
+    assert correct_prompt == our_prompt
 
     # Phi-3
     template = phi3_template
-    correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
-    correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
+    correct_tokenizer = AutoTokenizer.from_pretrained(
+        "microsoft/Phi-3-mini-4k-instruct"
+    )
+    correct_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
     correct_tokenizer.chat_template = template
-    our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
-    assert(correct_prompt == our_prompt)
-pass
+    our_prompt = correct_tokenizer.apply_chat_template(
+        messages[1:], tokenize = False, add_generation_prompt = True
+    )
+    assert correct_prompt == our_prompt
 
 
 def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
     """
-        Carefully checks the output of GGUF's tokenization and HF.
-        Can catch all tokenization bugs.
+    Carefully checks the output of GGUF's tokenization and HF.
+    Can catch all tokenization bugs.
     """
     import subprocess
     import re
+
     messages = [
         {"role": "user", "content": "What is 2+2?"},
         {"role": "assistant", "content": "It's 4."},
@@ -3219,33 +3503,51 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
 
     ### Response:
     {}""".format(
-        "Describe the city given eloquently.", # instruction
-        "The lost city of Atlantis.", # input
-        "", # output - leave this blank for generation!
+        "Describe the city given eloquently.",  # instruction
+        "The lost city of Atlantis.",  # input
+        "",  # output - leave this blank for generation!
     )
-    prompts = [ prompt, ]
+    prompts = [
+        prompt,
+    ]
 
     if tokenizer.chat_template is not None:
-        prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
-        prompt = prompt.replace("'", "") # Subprocess does not like ''
+        prompt = tokenizer.apply_chat_template(
+            messages, tokenize = False, add_generation_prompt = True
+        )
+        prompt = prompt.replace("'", "")  # Subprocess does not like ''
         prompt = remove_special_tokens(tokenizer, prompt)
         prompts.append(prompt)
-    pass
 
     for prompt in prompts:
-        command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
+        command = (
+            f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "
             f"--check-tensors -p '{prompt}'"
+        )
 
         datas = []
-        with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
+        with subprocess.Popen(
+            command,
+            shell = True,
+            stdout = subprocess.PIPE,
+            stderr = subprocess.STDOUT,
+            bufsize = 1,
+        ) as sp:
             for line in sp.stdout:
                 datas.append(line.decode("utf-8", errors = "replace"))
-        pass
         gguf_tokens = "".join(datas)
 
         # Now extract GGUF tokenization attempt
-        gguf_tokenized = re.findall(r"([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
-        gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
+        gguf_tokenized = re.findall(
+            r"([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE
+        )
+        gguf_tokenized = [
+            (
+                int(x[0]),
+                x[1],
+            )
+            for x in gguf_tokenized
+        ]
         input_ids = tokenizer(prompt).input_ids
 
         tokens = tokenizer.batch_decode(input_ids)
@@ -3253,7 +3555,7 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
 
         # Compare to Huggingface
         for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
-            if (hf_token[0] != gguf_token[0]):
+            if hf_token[0] != gguf_token[0]:
                 print("Failed GGUF != HF at", j)
                 print("HF =", hf_token)
                 print("GGUF =", gguf_token)
@@ -3262,7 +3564,4 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
                 print(gguf_tokenized)
                 print()
                 raise RuntimeError("Failed comparing GGUF to HF.")
-            pass
-        pass
     return True
-pass
diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py
index e70c6b50a..7735b1ecd 100644
--- a/unsloth/dataprep/synthetic.py
+++ b/unsloth/dataprep/synthetic.py
@@ -20,6 +20,7 @@
 from collections import deque
 import time
 import os
+
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 import requests
 import torch
@@ -31,51 +32,65 @@
     patch_vllm,
     delete_vllm,
 )
-from unsloth_zoo.log import logger 
+from unsloth_zoo.log import logger
 import numpy as np
 
 from .synthetic_configs import (
     synthetic_qa_config,
 )
 
-def terminate_tree(proc: subprocess.Popen, timeout=15):
+
+def terminate_tree(proc: subprocess.Popen, timeout = 15):
     if proc is None or proc.poll() is not None:
         return
-    
+
     try:
         import psutil
+
         parent = psutil.Process(proc.pid)
-        for child in parent.children(recursive=True):
+        for child in parent.children(recursive = True):
             child.terminate()
         parent.terminate()
-        parent.wait(timeout=timeout/2)
+        parent.wait(timeout = timeout / 2)
         return
     except:
         pass
-    
-    if os.name == 'nt':
+
+    if os.name == "nt":
         try:
             subprocess.run(
-                ['taskkill', '/T', '/F', '/PID', str(proc.pid)],
-                capture_output=True,
-                timeout=5
+                ["taskkill", "/T", "/F", "/PID", str(proc.pid)],
+                capture_output = True,
+                timeout = 5,
             )
-            proc.wait(timeout=1)
+            proc.wait(timeout = 1)
             return
         except:
             pass
-    
+
     proc.kill()
     try:
-        proc.wait(timeout=5)
+        proc.wait(timeout = 5)
     except:
         pass
 
+
 class PipeCapture:
     """Non blocking pipe capture"""
-    def __init__(self, pipe, keep_lines=2000, echo=False, name="", text=True, encoding='utf-8', errors='replace', ready_regex=None):
+
+    def __init__(
+        self,
+        pipe,
+        keep_lines = 2000,
+        echo = False,
+        name = "",
+        text = True,
+        encoding = "utf-8",
+        errors = "replace",
+        ready_regex = None,
+    ):
         self.pipe = pipe
-        self.buf = deque(maxlen=keep_lines)
+        self.buf = deque(maxlen = keep_lines)
         self.lock = threading.Lock()
         self.echo = echo
         self.name = name
@@ -92,18 +107,18 @@ def __init__(self, pipe, keep_lines=2000, echo=False, name="", text=True, encodi
                 ready_regex = re.compile(ready_regex)
             self.ready_regex = ready_regex
 
-        self.t = threading.Thread(target=self._reader, daemon=True)
+        self.t = threading.Thread(target = self._reader, daemon = True)
         self.t.start()
 
     def _reader(self):
         try:
-            sentinel = '' if self.text else b''
+            sentinel = "" if self.text else b""
             for raw_line in iter(self.pipe.readline, sentinel):
                 if not self.text:
                     line = raw_line.decode(self.encoding, self.errors)
                 else:
                     line = raw_line
-                line = line.rstrip('\r\n')
+                line = line.rstrip("\r\n")
                 if self.echo:
                     if "platform is" not in line:
                         print(f"{self.name}: {line}")
@@ -115,22 +130,25 @@ def _reader(self):
                     self.ready_event.set()
 
         finally:
-            try: self.pipe.close()
-            except Exception: pass
+            try:
+                self.pipe.close()
+            except Exception:
+                pass
             self.closed_event.set()
 
-    def wait_for_ready(self, timeout=None):
+    def wait_for_ready(self, timeout = None):
         return self.ready_event.wait(timeout)
 
     def has_closed(self):
         return self.closed_event.is_set()
 
-    def wait_until_closed(self, timeout=None):
+    def wait_until_closed(self, timeout = None):
         return self.closed_event.wait(timeout)
 
-    def tail(self, n=200):
+    def tail(self, n = 200):
         with self.lock:
-            return '\n'.join(list(self.buf)[-n:])
+            return "\n".join(list(self.buf)[-n:])
+
 
 class SyntheticDataKit:
     def __init__(
@@ -144,17 +162,18 @@ def __init__(
         timeout = 1200,  # maybe this is not enough for large models if we need to download
         **kwargs,
     ):
-        assert(type(model_name) is str)
-        assert(type(max_seq_length) is int)
-        assert(type(gpu_memory_utilization) is float)
-        assert(type(float8_kv_cache) is bool)
-        assert(type(conservativeness) is float)
-        assert(token is None or type(token) is str)
+        assert type(model_name) is str
+        assert type(max_seq_length) is int
+        assert type(gpu_memory_utilization) is float
+        assert type(float8_kv_cache) is bool
+        assert type(conservativeness) is float
+        assert token is None or type(token) is str
 
         self.model_name = model_name
         self.max_seq_length = max_seq_length
 
         from transformers import AutoConfig, AutoTokenizer
+
         self.config = AutoConfig.from_pretrained(
             model_name,
             token = token,
@@ -165,24 +184,27 @@ def __init__(
         )
         patch_vllm(debug = False)
         engine_args = load_vllm(
-            model_name             = model_name,
-            config                 = self.config,
+            model_name = model_name,
+            config = self.config,
             gpu_memory_utilization = gpu_memory_utilization,
-            max_seq_length         = max_seq_length,
-            disable_log_stats      = True,
-            float8_kv_cache        = float8_kv_cache,
-            conservativeness       = conservativeness,
-            return_args            = True,
-            enable_lora            = False,
-            use_bitsandbytes       = False,
-            compilation_config     = 3,
+            max_seq_length = max_seq_length,
+            disable_log_stats = True,
+            float8_kv_cache = float8_kv_cache,
+            conservativeness = conservativeness,
+            return_args = True,
+            enable_lora = False,
+            use_bitsandbytes = False,
+            compilation_config = 3,
             **kwargs,
         )
         if "dtype" in engine_args:
             dtype_val = engine_args["dtype"]
-            if   dtype_val == torch.float16:  dtype_val = "float16"
-            elif dtype_val == torch.bfloat16: dtype_val = "bfloat16"
-            elif dtype_val == torch.float32:  dtype_val = "float32"
+            if dtype_val == torch.float16:
+                dtype_val = "float16"
+            elif dtype_val == torch.bfloat16:
+                dtype_val = "bfloat16"
+            elif dtype_val == torch.float32:
+                dtype_val = "float32"
             engine_args["dtype"] = dtype_val
             # Convert torch.bfloat16, torch.float16, etc. to valid CLI string
             if hasattr(dtype_val, "name"):
@@ -193,11 +215,15 @@ def __init__(
             valid_dtypes = {"auto", "bfloat16", "float", "float16", "float32", "half"}
             if engine_args["dtype"] not in valid_dtypes:
                 engine_args["dtype"] = "auto"
-        if "device" in engine_args: del engine_args["device"]
-        if "model"  in engine_args: del engine_args["model"]
+        if "device" in engine_args:
+            del engine_args["device"]
+        if "model" in engine_args:
+            del engine_args["model"]
 
         subprocess_commands = [
-            "vllm", "serve", str(model_name),
+            "vllm",
+            "serve",
+            str(model_name),
         ]
         for key, value in engine_args.items():
             flag = key.replace("_", "-")
@@ -209,7 +235,9 @@ def __init__(
             which = str(value).replace("torch.", "")
             if which == "True":
                 # Ignore --enforce-eager True
-                subprocess_commands += ["--" + flag,]
+                subprocess_commands += [
+                    "--" + flag,
+                ]
             elif which == "False":
                 # Ignore flag
                 pass
@@ -217,8 +245,10 @@ def __init__(
                 # Ignore flag
                 pass
             else:
-                subprocess_commands += ["--" + flag, which,]
-        pass
+                subprocess_commands += [
+                    "--" + flag,
+                    which,
+                ]
         logger.info(subprocess_commands)
         vllm_process = subprocess.Popen(
             subprocess_commands,
@@ -228,12 +258,22 @@ def __init__(
         )
         ready_re = re.compile(r"Starting vLLM API server(?:\s+\d+)?\s+on\b")
         self.vllm_process = vllm_process
-        self.stdout_capture = PipeCapture(vllm_process.stdout, keep_lines = 1000,
-                                          echo = True, name = "vLLM STDOUT",
-                                          ready_regex = ready_re, text = False)
-        self.stderr_capture = PipeCapture(vllm_process.stderr, keep_lines = 2000,
-                                          echo = False, name = "vLLM STDERR",
-                                          ready_regex = None, text = False)
+        self.stdout_capture = PipeCapture(
+            vllm_process.stdout,
+            keep_lines = 1000,
+            echo = True,
+            name = "vLLM STDOUT",
+            ready_regex = ready_re,
+            text = False,
+        )
+        self.stderr_capture = PipeCapture(
+            vllm_process.stderr,
+            keep_lines = 2000,
+            echo = False,
+            name = "vLLM STDERR",
+            ready_regex = None,
+            text = False,
+        )
         # we don't print stderr to console but self.stderr_capture.tail(200) will print the last 200 lines
 
         ready = self.stdout_capture.wait_for_ready(timeout = timeout)
@@ -250,7 +290,6 @@ def __init__(
             return
         else:
             print("vLLM Server Ready Detected")
-        pass
 
         trial = 0
         while not self.check_vllm_status():
@@ -263,7 +302,6 @@ def __init__(
             trial += 1
             time.sleep(1)
         return
-    pass
 
     @staticmethod
     def from_pretrained(
@@ -284,7 +322,6 @@ def from_pretrained(
             token = token,
             **kwargs,
         )
-    pass
 
     @staticmethod
     def check_vllm_status():
@@ -294,58 +331,68 @@ def check_vllm_status():
                 return True
         except requests.exceptions.ConnectionError:
             return False
-        pass
-    pass
 
     def cleanup(self):
-        if not hasattr(self, "vllm_process"): return
+        if not hasattr(self, "vllm_process"):
+            return
 
         vllm_process = self.vllm_process
         print("Attempting to terminate the VLLM server gracefully...")
         try:
             vllm_process.terminate()
-            vllm_process.wait(timeout=10)
+            vllm_process.wait(timeout = 10)
             print("Server terminated gracefully.")
         except subprocess.TimeoutExpired:
-            print("Server did not terminate gracefully after 10 seconds. Forcing kill...")
+            print(
+                "Server did not terminate gracefully after 10 seconds. Forcing kill..."
+            )
             vllm_process.kill()
             vllm_process.wait()
             print("Server killed forcefully.")
         except Exception as e:
-             print(f"An error occurred while trying to stop the process: {e}")
-             try:
-                 if vllm_process.poll() is None:
-                     print("Attempting forceful kill due to error...")
-                     vllm_process.kill()
-                     vllm_process.wait()
-                     print("Server killed forcefully after error.")
-             except Exception as kill_e:
-                 print(f"Error during forceful kill: {kill_e}")
+            print(f"An error occurred while trying to stop the process: {e}")
+            try:
+                if vllm_process.poll() is None:
+                    print("Attempting forceful kill due to error...")
+                    vllm_process.kill()
+                    vllm_process.wait()
+                    print("Server killed forcefully after error.")
+            except Exception as kill_e:
+                print(f"Error during forceful kill: {kill_e}")
         for _ in range(10):
             torch.cuda.empty_cache()
             gc.collect()
 
         # Delete vLLM module as well
         delete_vllm(llm = None)
-    pass
 
-    def __enter__(self): return self
-    def __exit__(self, *exc): self.cleanup()
-    def __del__(self): self.cleanup()
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc):
+        self.cleanup()
+
+    def __del__(self):
+        self.cleanup()
 
     def chunk_data(self, filename = None):
         # Chunks data by max tokens and generation length
-        assert(filename is not None)
-        assert(os.path.exists(filename))
-        assert(hasattr(self, "tokenizer"))
+        assert filename is not None
+        assert os.path.exists(filename)
+        assert hasattr(self, "tokenizer")
         if not hasattr(self, "max_seq_length"):
-            raise RuntimeError("Please use SynthetidDataKit.from_pretrained(...) first!")
+            raise RuntimeError(
+                "Please use SynthetidDataKit.from_pretrained(...) first!"
+            )
         if not hasattr(self, "overlap") or not hasattr(self, "max_generation_tokens"):
             raise RuntimeError("Please use prepare_qa_generation first!")
 
-        with open(filename, "r", encoding = "utf-8") as f: text = f.read()
+        with open(filename, "r", encoding = "utf-8") as f:
+            text = f.read()
 
-        max_tokens = self.max_seq_length - self.max_generation_tokens*2 - 128 # -128 to reduce errors
+        max_tokens = (
+            self.max_seq_length - self.max_generation_tokens * 2 - 128
+        )  # -128 to reduce errors
         if max_tokens <= 5:
             raise RuntimeError("Generation length is way too long!")
         input_ids = self.tokenizer(text, add_special_tokens = False).input_ids
@@ -353,23 +400,25 @@ def chunk_data(self, filename = None):
         # Get left and right boundaries
         length = len(input_ids)
         n_chunks = int(np.ceil(length / (max_tokens - self.overlap)))
-        boundaries = np.ceil(np.linspace(0, length - self.overlap, n_chunks)).astype(int)
+        boundaries = np.ceil(np.linspace(0, length - self.overlap, n_chunks)).astype(
+            int
+        )
         boundaries = np.stack((boundaries[:-1], (boundaries + self.overlap)[1:])).T
         boundaries = np.minimum(boundaries, length).tolist()
 
         # Get extension of filename like .txt
         filename, extension = os.path.splitext(filename)
-        if filename.endswith("/"): filename = filename[:-1]
+        if filename.endswith("/"):
+            filename = filename[:-1]
 
         all_filenames = []
         for i, (left, right) in enumerate(boundaries):
-            chunked_text = self.tokenizer.decode(input_ids[left : right])
+            chunked_text = self.tokenizer.decode(input_ids[left:right])
             new_filename = f"{filename}_{i}{extension}"
             all_filenames.append(new_filename)
-            with open(new_filename, "w", encoding = "utf-8") as f: f.write(chunked_text)
-        pass
+            with open(new_filename, "w", encoding = "utf-8") as f:
+                f.write(chunked_text)
         return all_filenames
-    pass
 
     def prepare_qa_generation(
         self,
@@ -383,33 +432,34 @@ def prepare_qa_generation(
         cleanup_batch_size = 4,
         cleanup_temperature = 0.3,
     ):
-        assert(hasattr(self, "model_name"))
-        assert(hasattr(self, "max_seq_length"))
-        assert(max_generation_tokens < self.max_seq_length)
+        assert hasattr(self, "model_name")
+        assert hasattr(self, "max_seq_length")
+        assert max_generation_tokens < self.max_seq_length
 
         locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final"
         locations = locations.split(",")
         for path in locations:
             os.makedirs(os.path.join(output_folder, path), exist_ok = True)
-        pass
 
         self.max_generation_tokens = max_generation_tokens
 
-        config = synthetic_qa_config\
-            .replace("{data_output_location}", str(output_folder))\
-            .replace("{model_name}", str(self.model_name))\
-            .replace("{temperature}", str(temperature))\
-            .replace("{top_p}", str(top_p))\
-            .replace("{chunk_size}", str(self.max_seq_length - max_generation_tokens*2 - 2))\
-            .replace("{overlap}", str(overlap))\
-            .replace("{max_tokens}", str(max_generation_tokens))\
-            .replace("{default_num_pairs}", str(default_num_pairs))\
-            .replace("{cleanup_threshold}", str(cleanup_threshold))\
-            .replace("{cleanup_batch_size}", str(cleanup_batch_size))\
+        config = (
+            synthetic_qa_config.replace("{data_output_location}", str(output_folder))
+            .replace("{model_name}", str(self.model_name))
+            .replace("{temperature}", str(temperature))
+            .replace("{top_p}", str(top_p))
+            .replace(
+                "{chunk_size}", str(self.max_seq_length - max_generation_tokens * 2 - 2)
+            )
+            .replace("{overlap}", str(overlap))
+            .replace("{max_tokens}", str(max_generation_tokens))
+            .replace("{default_num_pairs}", str(default_num_pairs))
+            .replace("{cleanup_threshold}", str(cleanup_threshold))
+            .replace("{cleanup_batch_size}", str(cleanup_batch_size))
             .replace("{cleanup_temperature}", str(cleanup_temperature))
+        )
 
-        with open("synthetic_data_kit_config.yaml", "w", encoding = "utf-8") as f: f.write(config)
+        with open("synthetic_data_kit_config.yaml", "w", encoding = "utf-8") as f:
+            f.write(config)
 
         self.overlap = overlap
-    pass
-pass
diff --git a/unsloth/dataprep/synthetic_configs.py b/unsloth/dataprep/synthetic_configs.py
index f42817752..2e536467e 100644
--- a/unsloth/dataprep/synthetic_configs.py
+++ b/unsloth/dataprep/synthetic_configs.py
@@ -108,4 +108,4 @@
 
     DO NOT include any text outside of the JSON array, just return valid JSON:
 
-    {pairs}"""
\ No newline at end of file
+    {pairs}"""
diff --git a/unsloth/device_type.py b/unsloth/device_type.py
index ac70d2679..adc09b05d 100644
--- a/unsloth/device_type.py
+++ b/unsloth/device_type.py
@@ -27,10 +27,11 @@
 from unsloth_zoo.utils import Version
 import inspect
 
+
 @functools.cache
 def is_hip():
     return bool(getattr(getattr(torch, "version", None), "hip", None))
-pass
+
 
 @functools.cache
 def get_device_type():
@@ -43,20 +44,27 @@ def get_device_type():
     # Check torch.accelerator
     if hasattr(torch, "accelerator"):
         if not torch.accelerator.is_available():
-            raise NotImplementedError("Unsloth cannot find any torch accelerator? You need a GPU.")
+            raise NotImplementedError(
+                "Unsloth cannot find any torch accelerator? You need a GPU."
+            )
         accelerator = str(torch.accelerator.current_accelerator())
         if accelerator in ("cuda", "xpu", "hip"):
             raise RuntimeError(
-                f"Unsloth: Weirdly `torch.cuda.is_available()`, `torch.xpu.is_available()` and `is_hip` all failed.\n"\
-                f"But `torch.accelerator.current_accelerator()` works with it being = `{accelerator}`\n"\
+                f"Unsloth: Weirdly `torch.cuda.is_available()`, `torch.xpu.is_available()` and `is_hip` all failed.\n"
+                f"But `torch.accelerator.current_accelerator()` works with it being = `{accelerator}`\n"
                 f"Please reinstall torch - it's most likely broken :("
             )
-    raise NotImplementedError("Unsloth currently only works on NVIDIA, AMD and Intel GPUs.")
-pass
-DEVICE_TYPE : str = get_device_type()
+    raise NotImplementedError(
+        "Unsloth currently only works on NVIDIA, AMD and Intel GPUs."
+    )
+
+
+DEVICE_TYPE: str = get_device_type()
 # HIP fails for autocast and other torch functions. Use CUDA instead
 DEVICE_TYPE_TORCH = DEVICE_TYPE
-if DEVICE_TYPE_TORCH == "hip": DEVICE_TYPE_TORCH = "cuda"
+if DEVICE_TYPE_TORCH == "hip":
+    DEVICE_TYPE_TORCH = "cuda"
+
 
 @functools.cache
 def get_device_count():
@@ -66,22 +74,25 @@ def get_device_count():
         return torch.xpu.device_count()
     else:
         return 1
-pass
 
-DEVICE_COUNT : int = get_device_count()
+
+DEVICE_COUNT: int = get_device_count()
 
 # Check blocksize for 4bit -> 64 for CUDA, 128 for AMD
 # If AMD, we cannot load pre-quantized models for now :(
-ALLOW_PREQUANTIZED_MODELS : bool = True
+ALLOW_PREQUANTIZED_MODELS: bool = True
 # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB
-ALLOW_BITSANDBYTES : bool = True
+ALLOW_BITSANDBYTES: bool = True
 if DEVICE_TYPE == "hip":
     try:
         from bitsandbytes.nn.modules import Params4bit
-        if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(Params4bit):
+
+        if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(
+            Params4bit
+        ):
             ALLOW_PREQUANTIZED_MODELS = False
         import bitsandbytes
+
         ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0")
     except:
         pass
-pass
diff --git a/unsloth/import_fixes.py b/unsloth/import_fixes.py
index cac76909d..9ea9720cb 100644
--- a/unsloth/import_fixes.py
+++ b/unsloth/import_fixes.py
@@ -18,51 +18,74 @@
 from importlib.metadata import version as importlib_version
 from packaging.version import Version
 import logging
+
 UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
 
+
 # Ignore logging messages
 class HideLoggingMessage(logging.Filter):
-    __slots__ = "text",
-    def __init__(self, text): self.text = text
-    def filter(self, x): return not (self.text in x.getMessage())
-pass
+    __slots__ = ("text",)
+
+    def __init__(self, text):
+        self.text = text
+
+    def filter(self, x):
+        return not (self.text in x.getMessage())
+
 
 # Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
 # MUST do this at the start primarily due to tensorflow causing issues
 def fix_message_factory_issue():
     try:
         import google.protobuf.message_factory
+
         class MessageFactory:
-            def CreatePrototype(self, *args, **kwargs): return
-            def GetMessages(self, *args, **kwargs): return
-            def GetPrototype(self, *args, **kwargs): return
+            def CreatePrototype(self, *args, **kwargs):
+                return
+
+            def GetMessages(self, *args, **kwargs):
+                return
+
+            def GetPrototype(self, *args, **kwargs):
+                return
+
         if not hasattr(google.protobuf.message_factory, "MessageFactory"):
             if UNSLOTH_ENABLE_LOGGING:
                 print("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
             google.protobuf.message_factory.MessageFactory = MessageFactory
-        elif hasattr(google.protobuf.message_factory, "MessageFactory") and \
-            not hasattr(google.protobuf.message_factory.MessageFactory, "GetPrototype") and \
-            not hasattr(google.protobuf.message_factory, "GetMessageClass"):
+        elif (
+            hasattr(google.protobuf.message_factory, "MessageFactory")
+            and not hasattr(
+                google.protobuf.message_factory.MessageFactory, "GetPrototype"
+            )
+            and not hasattr(google.protobuf.message_factory, "GetMessageClass")
+        ):
             google.protobuf.message_factory.MessageFactory = MessageFactory
             if UNSLOTH_ENABLE_LOGGING:
                 print("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
-        elif hasattr(google.protobuf.message_factory, "MessageFactory") and \
-            not hasattr(google.protobuf.message_factory.MessageFactory, "GetPrototype") and \
-            hasattr(google.protobuf.message_factory, "GetMessageClass"):
+        elif (
+            hasattr(google.protobuf.message_factory, "MessageFactory")
+            and not hasattr(
+                google.protobuf.message_factory.MessageFactory, "GetPrototype"
+            )
+            and hasattr(google.protobuf.message_factory, "GetMessageClass")
+        ):
             GetMessageClass = google.protobuf.message_factory.GetMessageClass
+
             def GetPrototype(self, descriptor):
                 return GetMessageClass(descriptor)
+
             google.protobuf.message_factory.MessageFactory.GetPrototype = GetPrototype
             if UNSLOTH_ENABLE_LOGGING:
                 print("Unsloth: Patching protobuf.MessageFactory.GetPrototype")
-        pass
     except:
         pass
-pass
+
 
 # Fix Xformers performance issues since 0.0.25
 def fix_xformers_performance_issue():
-    if importlib.util.find_spec("xformers") is None: return
+    if importlib.util.find_spec("xformers") is None:
+        return
     xformers_version = importlib_version("xformers")
     if Version(xformers_version) < Version("0.0.29"):
         xformers_location = importlib.util.find_spec("xformers").origin
@@ -82,15 +105,18 @@ def fix_xformers_performance_issue():
                         f.write(text)
                         f.truncate()
                         if UNSLOTH_ENABLE_LOGGING:
-                            print("Unsloth: Patching Xformers to fix some performance issues.")
+                            print(
+                                "Unsloth: Patching Xformers to fix some performance issues."
+                            )
         except Exception as e:
             if UNSLOTH_ENABLE_LOGGING:
                 print(f"Unsloth: Failed patching Xformers with error = {str(e)}")
-pass
+
 
 # ValueError: 'aimv2' is already used by a Transformers config, pick another name.
 def fix_vllm_aimv2_issue():
-    if importlib.util.find_spec("vllm") is None: return
+    if importlib.util.find_spec("vllm") is None:
+        return
     vllm_version = importlib_version("vllm")
     if Version(vllm_version) < Version("0.10.1"):
         vllm_version = importlib.util.find_spec("vllm").origin
@@ -104,66 +130,72 @@ def fix_vllm_aimv2_issue():
                     if 'AutoConfig.register("aimv2", AIMv2Config)' in text:
                         text = text.replace(
                             'AutoConfig.register("aimv2", AIMv2Config)',
-                            '',
+                            "",
                         )
                         text = text.replace(
-                            '''backbone_config.pop('model_type')
+                            """backbone_config.pop('model_type')
                 backbone_config = AutoConfig.for_model(model_type,
-                                                       **backbone_config)''',
-                            '''if model_type != "aimv2":
+                                                       **backbone_config)""",
+                            """if model_type != "aimv2":
                     backbone_config.pop('model_type')
                     backbone_config = AutoConfig.for_model(model_type, **backbone_config)
                 else:
-                    backbone_config = AIMv2Config(**backbone_config)'''
+                    backbone_config = AIMv2Config(**backbone_config)""",
                         )
                         f.seek(0)
                         f.write(text)
                         f.truncate()
                         if UNSLOTH_ENABLE_LOGGING:
-                            print("Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`")
+                            print(
+                                "Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`"
+                            )
         except Exception as e:
             if UNSLOTH_ENABLE_LOGGING:
                 print(f"Unsloth: Failed patching vLLM with error = {str(e)}")
-pass
+
 
 def ignore_logger_messages():
     # Ignore Environment variable `HF_TOKEN` is set
     try:
         from huggingface_hub._login import logger as huggingface_hub_logger
+
         huggingface_hub_logger.addFilter(HideLoggingMessage("`HF_TOKEN`"))
         del huggingface_hub_logger
     except:
         pass
-pass
+
 
 def patch_ipykernel_hf_xet():
     # HF-XET == 1.1.10 and ipykernel == 7.0.0 causes issues
     # See https://github.com/huggingface/xet-core/issues/526
     # 2025-10-13T20:37:33.028737Z ERROR  Python exception updating progress:, error: PyErr { type: , value: LookupError(), traceback: Some() }, caller: "src/progress_update.rs:313"
     # at /home/runner/work/xet-core/xet-core/error_printer/src/lib.rs:28
-    if importlib.util.find_spec("hf_xet") is None: return
-    if importlib.util.find_spec("ipykernel") is None: return
-    if importlib.util.find_spec("huggingface_hub") is None: return
-    if (
-        Version(importlib_version("hf_xet")) == Version("1.1.10")
-    ) and (
+    if importlib.util.find_spec("hf_xet") is None:
+        return
+    if importlib.util.find_spec("ipykernel") is None:
+        return
+    if importlib.util.find_spec("huggingface_hub") is None:
+        return
+    if (Version(importlib_version("hf_xet")) == Version("1.1.10")) and (
         Version(importlib_version("ipykernel")) == Version("7.0.0")
     ):
         print(
-            "#### Unsloth: `hf_xet==1.1.10` and `ipykernel==7.0.0` breaks progress bars. Disabling for now in XET.\n"\
-            "#### Unsloth: To re-enable progress bars, please upgrade to `ipykernel>7.0.0` or wait for a fix to\n"\
+            "#### Unsloth: `hf_xet==1.1.10` and `ipykernel==7.0.0` breaks progress bars. Disabling for now in XET.\n"
+            "#### Unsloth: To re-enable progress bars, please upgrade to `ipykernel>7.0.0` or wait for a fix to\n"
             "https://github.com/huggingface/xet-core/issues/526"
         )
         from huggingface_hub.utils import disable_progress_bars
+
         disable_progress_bars()
-    pass
-pass
+
 
 def patch_trackio():
     # Set some environment variables to customize the Trackio dashboard for experiment tracking
     # See https://github.com/unslothai/notebooks/pull/110
-    os.environ["TRACKIO_LOGO_LIGHT_URL"] = "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png"
-    os.environ["TRACKIO_LOGO_DARK_URL"] = "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png"
+    os.environ["TRACKIO_LOGO_LIGHT_URL"] = (
+        "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png"
+    )
+    os.environ["TRACKIO_LOGO_DARK_URL"] = (
+        "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png"
+    )
     os.environ["TRACKIO_PLOT_ORDER"] = "train/reward"
-    pass
-pass
diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py
index 053a3b25f..15913413d 100644
--- a/unsloth/kernels/__init__.py
+++ b/unsloth/kernels/__init__.py
@@ -44,8 +44,14 @@
     apply_lora_o,
     fast_lora_forward,
 )
-from .fp8 import * # This step is to ensure that we patch the FbgmemFP8Linear and FP8Linear's forward functions before the execution of model creation so that this applies to compiled non fast inference models as well
-from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
+from .fp8 import *  # This step is to ensure that we patch the FbgmemFP8Linear and FP8Linear's forward functions before the execution of model creation so that this applies to compiled non fast inference models as well
+from .utils import (
+    fast_dequantize,
+    fast_gemv,
+    QUANT_STATE,
+    fast_linear_forward,
+    matmul_lora,
+)
 
 from .flex_attention import (
     HAS_FLEX_ATTENTION,
@@ -56,11 +62,12 @@
 )
 
 import os
+
 if "UNSLOTH_ZOO_IS_PRESENT" not in os.environ:
     try:
-        print("𦄠Unsloth: Will patch your computer to enable 2x faster free finetuning.")
+        print(
+            "𦄠Unsloth: Will patch your computer to enable 2x faster free finetuning."
+        )
     except:
         print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
-    pass
-pass
 del os
diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py
index d3b618582..912e6f7e3 100644
--- a/unsloth/kernels/cross_entropy_loss.py
+++ b/unsloth/kernels/cross_entropy_loss.py
@@ -33,134 +33,145 @@
 
 
 def _cross_entropy_forward(
-    logits_ptr        ,
-    logits_row_stride ,
-    loss_ptr          ,
-    logsumexp_ptr     ,
-    labels_ptr        ,
-    VOCAB_SIZE        : tl.constexpr,
-    BLOCK_SIZE        : tl.constexpr,
-    DO_SOFTCAPPING    : tl.constexpr,
-    SOFTCAP           : tl.constexpr,
-    DO_LOGIT_SCALING  : tl.constexpr,
-    LOGIT_SCALE       : tl.constexpr,
+    logits_ptr,
+    logits_row_stride,
+    loss_ptr,
+    logsumexp_ptr,
+    labels_ptr,
+    VOCAB_SIZE: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    DO_SOFTCAPPING: tl.constexpr,
+    SOFTCAP: tl.constexpr,
+    DO_LOGIT_SCALING: tl.constexpr,
+    LOGIT_SCALE: tl.constexpr,
 ):
     """
-        Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
-        Pi = exp(xi) / sum(exp(xi))
-        CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
-             = -y [ x - log[sum(exp(x))] ]
-             = y * (log[sum(exp(x))] - x)
-        If y == 0: CE_i = 0
-        If y == 1: CE_i = logsumexp - x
-
-        logsumexp is also stable
-        Take    y =         log[sum(exp(x))]
-           exp(y) =             sum(exp(x))
-           exp(y) =             sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
-           exp(y) =      exp(c)*sum(exp(x - c))
-               y  = log(exp(c)*sum(exp(x - c)))
-               y  = c + log[sum(exp(x - c))]
-        This means we can set c = max(x) to make sure
-        exp(x - c) always is exp(x - max(x)).
-        This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
+    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
+    Pi = exp(xi) / sum(exp(xi))
+    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
+         = -y [ x - log[sum(exp(x))] ]
+         = y * (log[sum(exp(x))] - x)
+    If y == 0: CE_i = 0
+    If y == 1: CE_i = logsumexp - x
+
+    logsumexp is also stable
+    Take    y =         log[sum(exp(x))]
+       exp(y) =             sum(exp(x))
+       exp(y) =             sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
+       exp(y) =      exp(c)*sum(exp(x - c))
+           y  = log(exp(c)*sum(exp(x - c)))
+           y  = c + log[sum(exp(x - c))]
+    This means we can set c = max(x) to make sure
+    exp(x - c) always is exp(x - max(x)).
+    This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
     """
     row_idx = tl.program_id(0)
-    logits_ptr    += row_idx * triton_cast(logits_row_stride, tl.int64)
-    loss_ptr      += row_idx
+    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
+    loss_ptr += row_idx
     logsumexp_ptr += row_idx
-    labels_ptr    += row_idx
+    labels_ptr += row_idx
 
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < VOCAB_SIZE
 
     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
+        tl.float32
+    )
 
     # Go logit scaling for Cohere: t * x
-    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
+    if DO_LOGIT_SCALING:
+        logits = LOGIT_SCALE * logits
     # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-    if DO_SOFTCAPPING:   logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
-    
+    if DO_SOFTCAPPING:
+        logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
+
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
 
     if label_idx != -100:
         x = tl.load(logits_ptr + label_idx).to(tl.float32)
         # Go logit scaling for Cohere: t * x
-        if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
+        if DO_LOGIT_SCALING:
+            x = LOGIT_SCALE * x
         # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-        if DO_SOFTCAPPING:   x = SOFTCAP * triton_tanh(x / SOFTCAP)
+        if DO_SOFTCAPPING:
+            x = SOFTCAP * triton_tanh(x / SOFTCAP)
         loss = logsumexp - x
     else:
         loss = 0.0
     tl.store(logsumexp_ptr, logsumexp)
     tl.store(loss_ptr, loss)
-pass
+
+
 _cross_entropy_forward = triton.jit(_cross_entropy_forward)
 _cross_entropy_forward = triton.heuristics(
     {
-        "DO_SOFTCAPPING":   lambda args: bool(args["DO_SOFTCAPPING"  ]),
+        "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
         "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
     }
 )(_cross_entropy_forward)
 
 
 def _chunked_cross_entropy_forward(
-    logits_ptr        ,
-    logits_row_stride : tl.constexpr,
-    loss_ptr          ,
-    logsumexp_ptr     ,
-    labels_ptr        ,
-    VOCAB_SIZE        : tl.constexpr,
-    N_CHUNKS          : tl.constexpr,
-    BLOCK_SIZE        : tl.constexpr,
-    DO_SOFTCAPPING    : tl.constexpr,
-    SOFTCAP           : tl.constexpr,
-    DO_LOGIT_SCALING  : tl.constexpr,
-    LOGIT_SCALE       : tl.constexpr,
+    logits_ptr,
+    logits_row_stride: tl.constexpr,
+    loss_ptr,
+    logsumexp_ptr,
+    labels_ptr,
+    VOCAB_SIZE: tl.constexpr,
+    N_CHUNKS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    DO_SOFTCAPPING: tl.constexpr,
+    SOFTCAP: tl.constexpr,
+    DO_LOGIT_SCALING: tl.constexpr,
+    LOGIT_SCALE: tl.constexpr,
 ):
     """
-        256K vocab divided in 4 chunks
+    256K vocab divided in 4 chunks
 
-        |-65536-| |-65536-| |-65536-| |-65536-|
-        |-------| |-------| |-------| |-------|
-        |-------| |-------| |-------| |-------|
+    |-65536-| |-65536-| |-65536-| |-65536-|
+    |-------| |-------| |-------| |-------|
+    |-------| |-------| |-------| |-------|
 
-        If y == 0: CE_i = 0
-        If y == 1: CE_i = logsumexp - x
+    If y == 0: CE_i = 0
+    If y == 1: CE_i = logsumexp - x
 
-        Notice we can do logsumexp for each chunk and then
-        logsumexp[chunk_sum(logsumexp)] == logsumexp
+    Notice we can do logsumexp for each chunk and then
+    logsumexp[chunk_sum(logsumexp)] == logsumexp
 
-        chunk_sum = log[chunk_sum(logsumexp)]
-                  = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
-                  = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
-                  = log[sum(exp(a)) + ... + sum(exp(z))]
-                  = logsumexp(x)
+    chunk_sum = log[chunk_sum(logsumexp)]
+              = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
+              = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
+              = log[sum(exp(a)) + ... + sum(exp(z))]
+              = logsumexp(x)
 
-        This means we can perform a logsumexp for each chunk, then do a
-        final logsumexp reduction!
+    This means we can perform a logsumexp for each chunk, then do a
+    final logsumexp reduction!
 
-        Ie do: logsumexp(chunked_logsumexp) - x
+    Ie do: logsumexp(chunked_logsumexp) - x
     """
-    row_idx   = tl.program_id(0)
+    row_idx = tl.program_id(0)
     chunk_idx = tl.program_id(1)
-    logits_ptr    += row_idx * triton_cast(logits_row_stride, tl.int64)
-    loss_ptr      += row_idx
+    logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
+    loss_ptr += row_idx
     logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
-    labels_ptr    += row_idx
+    labels_ptr += row_idx
 
-    col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < VOCAB_SIZE
 
     label_idx = tl.load(labels_ptr).to(tl.int32)
-    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
+        tl.float32
+    )
 
     # Go logit scaling for Cohere: t * x
-    if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
+    if DO_LOGIT_SCALING:
+        logits = LOGIT_SCALE * logits
     # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-    if DO_SOFTCAPPING:   logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
+    if DO_SOFTCAPPING:
+        logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
 
     c = tl.max(logits, 0)
     logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
@@ -171,60 +182,62 @@ def _chunked_cross_entropy_forward(
         if label_idx != -100:
             x = tl.load(logits_ptr + label_idx).to(tl.float32)
             # Go logit scaling for Cohere: t * x
-            if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
+            if DO_LOGIT_SCALING:
+                x = LOGIT_SCALE * x
             # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
-            if DO_SOFTCAPPING:   x = SOFTCAP * triton_tanh(x / SOFTCAP)
+            if DO_SOFTCAPPING:
+                x = SOFTCAP * triton_tanh(x / SOFTCAP)
             loss = -1.0 * x
         else:
             loss = 0.0
         tl.store(loss_ptr, loss)
-    pass
     tl.store(logsumexp_ptr, logsumexp)
-pass
+
+
 _chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward)
 _chunked_cross_entropy_forward = triton.heuristics(
     {
-        "DO_SOFTCAPPING":   lambda args: bool(args["DO_SOFTCAPPING"  ]),
+        "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
         "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
     }
 )(_chunked_cross_entropy_forward)
 
 
 def _cross_entropy_backward(
-    logits_ptr        ,
-    logits_row_stride : tl.constexpr,
-    dloss_ptr         ,
-    dloss_row_stride  : tl.constexpr,
-    logsumexp_ptr     ,
-    labels_ptr        ,
-    VOCAB_SIZE        : tl.constexpr,
-    BLOCK_SIZE        : tl.constexpr,
-    DO_SOFTCAPPING    : tl.constexpr,
-    SOFTCAP           : tl.constexpr,
-    DO_LOGIT_SCALING  : tl.constexpr,
-    LOGIT_SCALE       : tl.constexpr,
+    logits_ptr,
+    logits_row_stride: tl.constexpr,
+    dloss_ptr,
+    dloss_row_stride: tl.constexpr,
+    logsumexp_ptr,
+    labels_ptr,
+    VOCAB_SIZE: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    DO_SOFTCAPPING: tl.constexpr,
+    SOFTCAP: tl.constexpr,
+    DO_LOGIT_SCALING: tl.constexpr,
+    LOGIT_SCALE: tl.constexpr,
 ):
     """
-        CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
-        dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
+    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
+    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
 
-        From https://en.wikipedia.org/wiki/LogSumExp
-        d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
+    From https://en.wikipedia.org/wiki/LogSumExp
+    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
 
-        dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
-        dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
-        dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
+    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
+    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
+    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
 
-        If y == 0: dC/dx = 0
-        If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
-        If y == 1 and x != label: dC/dx     = exp[x - logsumexp]
+    If y == 0: dC/dx = 0
+    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
+    If y == 1 and x != label: dC/dx     = exp[x - logsumexp]
     """
-    row_idx   = tl.program_id(0)
+    row_idx = tl.program_id(0)
     block_idx = tl.program_id(1)
 
     logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
-    dloss_ptr  += row_idx *  dloss_row_stride
-    col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    dloss_ptr += row_idx * dloss_row_stride
+    col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < VOCAB_SIZE
     label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
 
@@ -239,7 +252,6 @@ def _cross_entropy_backward(
     if DO_LOGIT_SCALING:
         # d/dx [s * x] = s
         x = x * LOGIT_SCALE
-    pass
 
     # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
     partial = x
@@ -247,141 +259,165 @@ def _cross_entropy_backward(
         # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
         partial = triton_tanh(x / SOFTCAP)
         x = SOFTCAP * partial
-    pass
 
     logsumexp = tl.load(logsumexp_ptr + row_idx)
     y = tl.exp(x - logsumexp)
     y = tl.where(
         col_offsets == label_idx,
-        y - 1.0, # exp(x - logsumexp) - 1
-        y,       # exp(x - logsumexp)
+        y - 1.0,  # exp(x - logsumexp) - 1
+        y,  # exp(x - logsumexp)
     )
 
     if DO_LOGIT_SCALING:
         # d/dx [s * x] = s
         y = y * LOGIT_SCALE
-    pass
 
     if DO_SOFTCAPPING:
         # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
-        y = y * (1.0 - partial*partial)
-    pass
+        y = y * (1.0 - partial * partial)
 
     # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
     tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
-pass
+
+
 _cross_entropy_backward = triton.jit(_cross_entropy_backward)
 _cross_entropy_backward = triton.heuristics(
     {
-        "DO_SOFTCAPPING":   lambda args: bool(args["DO_SOFTCAPPING"  ]),
+        "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
         "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
     }
 )(_cross_entropy_backward)
 
 
-MAX_FUSED_SIZE = 65536 # 2**16
+MAX_FUSED_SIZE = 65536  # 2**16
+
+
 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
-        n_rows : int
-        vocab_size : int
+    def forward(
+        ctx, logits, labels, logit_softcapping: float = 0, logit_scaling: float = 0
+    ):
+        n_rows: int
+        vocab_size: int
         n_rows, vocab_size = logits.shape
         device = logits.device
 
         div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
-        n_chunks : int = div + (mod != 0)
+        n_chunks: int = div + (mod != 0)
         losses = torch.empty(n_rows, dtype = torch.float32, device = device)
 
-        DO_SOFTCAPPING   : bool = bool(logit_softcapping != 0)
-        DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
+        DO_SOFTCAPPING: bool = bool(logit_softcapping != 0)
+        DO_LOGIT_SCALING: bool = bool(logit_scaling != 0)
 
-        BLOCK_SIZE : int
-        num_warps  : int
+        BLOCK_SIZE: int
+        num_warps: int
         if n_chunks == 1:
             # For small vocabs <= 65336 like Llama, Mistral
             BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
-            if is_cdna(): num_warps = num_warps // 2
+            if is_cdna():
+                num_warps = num_warps // 2
             logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)
 
             with torch_gpu_device(device):
                 _cross_entropy_forward[(n_rows,)](
-                    logits, logits.stride(0),
+                    logits,
+                    logits.stride(0),
                     losses,
                     logsumexp,
                     labels,
-                    VOCAB_SIZE       = vocab_size,
-                    BLOCK_SIZE       = BLOCK_SIZE,
-                    DO_SOFTCAPPING   = DO_SOFTCAPPING,
-                    SOFTCAP          = logit_softcapping,
+                    VOCAB_SIZE = vocab_size,
+                    BLOCK_SIZE = BLOCK_SIZE,
+                    DO_SOFTCAPPING = DO_SOFTCAPPING,
+                    SOFTCAP = logit_softcapping,
                     DO_LOGIT_SCALING = DO_LOGIT_SCALING,
-                    LOGIT_SCALE      = logit_scaling,
-                    num_warps        = num_warps,
+                    LOGIT_SCALE = logit_scaling,
+                    num_warps = num_warps,
                 )
         else:
             # For large vocabs > 65336 like Gemma 256K
-            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device)
+            logsumexp = torch.empty(
+                (
+                    n_rows,
+                    n_chunks,
+                ),
+                dtype = torch.float32,
+                device = device,
+            )
 
             with torch_gpu_device(device):
-                _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
-                    logits, logits.stride(0),
+                _chunked_cross_entropy_forward[
+                    (
+                        n_rows,
+                        n_chunks,
+                    )
+                ](
+                    logits,
+                    logits.stride(0),
                     losses,
                     logsumexp,
                     labels,
-                    VOCAB_SIZE       = vocab_size,
-                    N_CHUNKS         = n_chunks,
-                    BLOCK_SIZE       = MAX_FUSED_SIZE,
-                    DO_SOFTCAPPING   = DO_SOFTCAPPING,
-                    SOFTCAP          = logit_softcapping,
+                    VOCAB_SIZE = vocab_size,
+                    N_CHUNKS = n_chunks,
+                    BLOCK_SIZE = MAX_FUSED_SIZE,
+                    DO_SOFTCAPPING = DO_SOFTCAPPING,
+                    SOFTCAP = logit_softcapping,
                     DO_LOGIT_SCALING = DO_LOGIT_SCALING,
-                    LOGIT_SCALE      = logit_scaling,
-                    num_warps        = 32 if not is_cdna() else 16,
+                    LOGIT_SCALE = logit_scaling,
+                    num_warps = 32 if not is_cdna() else 16,
                 )
             # logsumexp(chunked_logsumexp) - x
             # Do the -x separately
-            logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
+            logsumexp = torch.logsumexp(logsumexp, dim = 1)  # Row sum
             losses += logsumexp
-            losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
-        pass
+            losses.masked_fill_(labels == -100, 0)  # Don't forget to mask padding out!
 
         ctx.save_for_backward(logits, logsumexp, labels)
-        ctx.DO_SOFTCAPPING    = DO_SOFTCAPPING
+        ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
         ctx.logit_softcapping = logit_softcapping
-        ctx.DO_LOGIT_SCALING  = DO_LOGIT_SCALING
-        ctx.logit_scaling     = logit_scaling
+        ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
+        ctx.logit_scaling = logit_scaling
         return losses
-    pass
-
 
     @staticmethod
     def backward(ctx, dlosses):
         logits, logsumexp, labels = ctx.saved_tensors
-        n_rows : int
-        vocab_size : int
+        n_rows: int
+        vocab_size: int
         n_rows, vocab_size = logits.shape
 
-        BLOCK_SIZE : int = 4096
-        div : int
-        mod : int
+        BLOCK_SIZE: int = 4096
+        div: int
+        mod: int
         div, mod = divmod(vocab_size, BLOCK_SIZE)
-        n_blocks : int = div + (mod != 0)
+        n_blocks: int = div + (mod != 0)
 
         with torch_gpu_device(dlosses.device):
-            _cross_entropy_backward[(n_rows, n_blocks,)](
-                logits,   logits.stride(0),
-                dlosses, dlosses.stride(0),
+            _cross_entropy_backward[
+                (
+                    n_rows,
+                    n_blocks,
+                )
+            ](
+                logits,
+                logits.stride(0),
+                dlosses,
+                dlosses.stride(0),
                 logsumexp,
                 labels,
-                VOCAB_SIZE       = vocab_size,
-                BLOCK_SIZE       = BLOCK_SIZE,
-                DO_SOFTCAPPING   = ctx.DO_SOFTCAPPING,
-                SOFTCAP          = ctx.logit_softcapping,
+                VOCAB_SIZE = vocab_size,
+                BLOCK_SIZE = BLOCK_SIZE,
+                DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
+                SOFTCAP = ctx.logit_softcapping,
                 DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
-                LOGIT_SCALE      = ctx.logit_scaling,
-                num_warps        = 8,
+                LOGIT_SCALE = ctx.logit_scaling,
+                num_warps = 8,
             )
-        return logits, None, None, None,
-    pass
-pass
+        return (
+            logits,
+            None,
+            None,
+            None,
+        )
 
 
 def fast_cross_entropy_loss(
@@ -399,10 +435,10 @@ def fast_cross_entropy_loss(
         losses: float
     """
     batch, seq_len, d = logits.shape
-    assert(labels.shape == (batch, seq_len))
+    assert labels.shape == (batch, seq_len)
 
     loss = Fast_CrossEntropyLoss.apply(
-        logits.view(batch*seq_len, d),
+        logits.view(batch * seq_len, d),
         labels.view(-1),
         logit_softcapping,
         logit_scaling,
@@ -410,13 +446,14 @@ def fast_cross_entropy_loss(
     if n_items is None:
         n_items = torch.count_nonzero(labels != -100)
     return loss.sum() / n_items
-pass
-if (Version(torch.__version__) < Version("2.4.0")) and \
-    not hasattr(fast_cross_entropy_loss, "__wrapped__"):
+
+
+if (Version(torch.__version__) < Version("2.4.0")) and not hasattr(
+    fast_cross_entropy_loss, "__wrapped__"
+):
     fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)
-pass
+
 
 # Patch CE Losses in transformers
 def patch_loss_functions(torch_compile = True):
     _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
-pass
diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py
index d0d8627f0..f3e0b7f7a 100644
--- a/unsloth/kernels/fast_lora.py
+++ b/unsloth/kernels/fast_lora.py
@@ -63,54 +63,95 @@ class LoRA_MLP(torch.autograd.Function):
 
     Don't forget to see our blog post for more details!
     """
+
     @staticmethod
     @torch_amp_custom_fwd
-    def forward(ctx, X : torch.Tensor,
-                gateW, gateW_quant, gateA, gateB, gateS,
-                  upW,   upW_quant, upA,   upB,   upS,
-                downW, downW_quant, downA, downB, downS,
-                _forward_function, _backward_function,
-                inplace = True,):
+    def forward(
+        ctx,
+        X: torch.Tensor,
+        gateW,
+        gateW_quant,
+        gateA,
+        gateB,
+        gateS,
+        upW,
+        upW_quant,
+        upA,
+        upB,
+        upS,
+        downW,
+        downW_quant,
+        downA,
+        downB,
+        downS,
+        _forward_function,
+        _backward_function,
+        inplace = True,
+    ):
         dtype = X.dtype
 
         e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
-        g = matmul_lora(X,   upW,   upW_quant,   upA,   upB,   upS)
+        g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
         h = _forward_function(e, g)
         i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
 
         ctx.custom_saved_tensors = (
-            gateW, gateW_quant, gateS,
-            upW, upW_quant, upS,
-            downW, downW_quant, downS,
+            gateW,
+            gateW_quant,
+            gateS,
+            upW,
+            upW_quant,
+            upS,
+            downW,
+            downW_quant,
+            downS,
             _backward_function,
         )
-        ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
-                              X, e, g)
+        ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, X, e, g)
         ctx.inplace = inplace
         return i
-    pass
-
 
     @staticmethod
     @torch_amp_custom_bwd
-    def backward(ctx, dY : torch.Tensor):
-        gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
-            _backward_function = ctx.custom_saved_tensors
-        gateA, gateB, upA, upB, downA, downB, \
-            X, e, g = ctx.saved_tensors
+    def backward(ctx, dY: torch.Tensor):
+        (
+            gateW,
+            gateW_quant,
+            gateS,
+            upW,
+            upW_quant,
+            upS,
+            downW,
+            downW_quant,
+            downS,
+            _backward_function,
+        ) = ctx.custom_saved_tensors
+        gateA, gateB, upA, upB, downA, downB, X, e, g = ctx.saved_tensors
 
         batch, seq_len, hd = X.shape
         dY = dY.view(-1, dY.shape[-1])
-        X  = X .view(-1, X .shape[-1])
-        e  = e .view(-1, e .shape[-1])
-        g  = g .view(-1, g .shape[-1])
+        X = X.view(-1, X.shape[-1])
+        e = e.view(-1, e.shape[-1])
+        g = g.view(-1, g.shape[-1])
         dtype = X.dtype
 
-        gateA, gateB, upA, upB, downA, downB = \
-            gateA.to(dtype), gateB.to(dtype), upA.to(dtype), upB.to(dtype), downA.to(dtype), downB.to(dtype)
+        gateA, gateB, upA, upB, downA, downB = (
+            gateA.to(dtype),
+            gateB.to(dtype),
+            upA.to(dtype),
+            upB.to(dtype),
+            downA.to(dtype),
+            downB.to(dtype),
+        )
 
-        gateA, gateB, upA, upB, downA, downB = \
-            gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
+        gateA, gateB, upA, upB, downA, downB = (
+            gateA.t(),
+            gateB.t(),
+            upA.t(),
+            upB.t(),
+            downA.t(),
+            downB.t(),
+        )
 
         DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
         DW, e, g = _backward_function(DW, e, g)
@@ -120,8 +161,8 @@ def backward(ctx, dY : torch.Tensor):
         d_downB = torch.empty_like(downB)
         d_gateA = torch.empty_like(gateA)
         d_gateB = torch.empty_like(gateB)
-        d_upA   = torch.empty_like(upA)
-        d_upB   = torch.empty_like(upB)
+        d_upA = torch.empty_like(upA)
+        d_upB = torch.empty_like(upB)
 
         # Down projection LoRA weights
         # d_downA = h.t() @ (dY @ downB.t())
@@ -165,60 +206,122 @@ def backward(ctx, dY : torch.Tensor):
         # gateW, gateW_quant, gateA, gateB, gateS,
         #  upW,    upW_quant,   upA,   upB,   upS,
         # downW, downW_quant, downA, downB, downS,
-        return dX.view(batch, seq_len, hd), \
-            None, None, d_gateA.t(), d_gateB.t(), None, \
-            None, None,   d_upA.t(),   d_upB.t(), None, \
-            None, None, d_downA.t(), d_downB.t(), None, \
-            None, None, None, # _backward and _forward and inplace
-    pass
-pass
+        return (
+            dX.view(batch, seq_len, hd),
+            None,
+            None,
+            d_gateA.t(),
+            d_gateB.t(),
+            None,
+            None,
+            None,
+            d_upA.t(),
+            d_upB.t(),
+            None,
+            None,
+            None,
+            d_downA.t(),
+            d_downB.t(),
+            None,
+            None,
+            None,
+            None,
+        )  # _backward and _forward and inplace
 
 
 from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
+
+
 def apply_lora_mlp_swiglu(self, X, inplace = True):
     X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
-    upW,     upW_quant,   upA,   upB,   upS = get_lora_parameters(self.  up_proj)
+    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
-    out = LoRA_MLP.apply(X,
-                         gateW, gateW_quant, gateA, gateB, gateS,
-                         upW,     upW_quant, upA,   upB,   upS,
-                         downW, downW_quant, downA, downB, downS,
-                         swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
-                         inplace,)
+    out = LoRA_MLP.apply(
+        X,
+        gateW,
+        gateW_quant,
+        gateA,
+        gateB,
+        gateS,
+        upW,
+        upW_quant,
+        upA,
+        upB,
+        upS,
+        downW,
+        downW_quant,
+        downA,
+        downB,
+        downS,
+        swiglu_fg_kernel,
+        swiglu_DWf_DW_dfg_kernel,
+        inplace,
+    )
     return out
-pass
 
 
 from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
+
+
 def apply_lora_mlp_geglu_exact(self, X, inplace = True):
     X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
-    upW,     upW_quant,   upA,   upB,   upS = get_lora_parameters(self.  up_proj)
+    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
-    out = LoRA_MLP.apply(X,
-                         gateW, gateW_quant, gateA, gateB, gateS,
-                         upW,     upW_quant, upA,   upB,   upS,
-                         downW, downW_quant, downA, downB, downS,
-                         geglu_exact_forward_kernel, geglu_exact_backward_kernel,
-                         inplace,)
+    out = LoRA_MLP.apply(
+        X,
+        gateW,
+        gateW_quant,
+        gateA,
+        gateB,
+        gateS,
+        upW,
+        upW_quant,
+        upA,
+        upB,
+        upS,
+        downW,
+        downW_quant,
+        downA,
+        downB,
+        downS,
+        geglu_exact_forward_kernel,
+        geglu_exact_backward_kernel,
+        inplace,
+    )
     return out
-pass
 
 
 from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
+
+
 def apply_lora_mlp_geglu_approx(self, X):
     X = _maybe_fake_quantize_activations(X, self.gate_proj)
     gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
-    upW,     upW_quant,   upA,   upB,   upS = get_lora_parameters(self.  up_proj)
+    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
     downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
-    out = LoRA_MLP.apply(X,
-                         gateW, gateW_quant, gateA, gateB, gateS,
-                         upW,     upW_quant, upA,   upB,   upS,
-                         downW, downW_quant, downA, downB, downS,
-                         geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
+    out = LoRA_MLP.apply(
+        X,
+        gateW,
+        gateW_quant,
+        gateA,
+        gateB,
+        gateS,
+        upW,
+        upW_quant,
+        upA,
+        upB,
+        upS,
+        downW,
+        downW_quant,
+        downA,
+        downB,
+        downS,
+        geglu_approx_forward_kernel,
+        geglu_approx_backward_kernel,
+    )
     return out
-pass
 
 
 class LoRA_QKV(torch.autograd.Function):
@@ -251,13 +354,29 @@ class LoRA_QKV(torch.autograd.Function):
     dC/dAv =       X.T @ D(Wv) @ B.T
     dC/dBv = A.T @ X.T @ D(Wv)
     """
+
     @staticmethod
     @torch_amp_custom_fwd
-    def forward(ctx, X : torch.Tensor,
-                QW, QW_quant, QA, QB, QS,
-                KW, KW_quant, KA, KB, KS,
-                VW, VW_quant, VA, VB, VS,
-                inplace = True):
+    def forward(
+        ctx,
+        X: torch.Tensor,
+        QW,
+        QW_quant,
+        QA,
+        QB,
+        QS,
+        KW,
+        KW_quant,
+        KA,
+        KB,
+        KS,
+        VW,
+        VW_quant,
+        VA,
+        VB,
+        VS,
+        inplace = True,
+    ):
         dtype = X.dtype
 
         Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
@@ -265,34 +384,59 @@ def forward(ctx, X : torch.Tensor,
         V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
 
         ctx.custom_saved_tensors = (
-            QW, QW_quant, QS,
-            KW, KW_quant, KS,
-            VW, VW_quant, VS,
+            QW,
+            QW_quant,
+            QS,
+            KW,
+            KW_quant,
+            KS,
+            VW,
+            VW_quant,
+            VS,
+        )
+        ctx.save_for_backward(
+            X,
+            QA,
+            QB,
+            KA,
+            KB,
+            VA,
+            VB,
         )
-        ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
         ctx.inplace = inplace
         return Q, K, V
-    pass
 
     @staticmethod
     @torch_amp_custom_bwd
     def backward(ctx, dQ, dK, dV):
-        QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
-            ctx.custom_saved_tensors
-        X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
+        QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = ctx.custom_saved_tensors
+        (
+            X,
+            QA,
+            QB,
+            KA,
+            KB,
+            VA,
+            VB,
+        ) = ctx.saved_tensors
 
         batch, seq_len, hd = X.shape
         dQ = dQ.view(-1, dQ.shape[-1])
-        dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
+        dK = dK.reshape(-1, dK.shape[-1])  # view doesn't work on K.T
         dV = dV.view(-1, dV.shape[-1])
-        X  = X .view(-1, X .shape[-1])
+        X = X.view(-1, X.shape[-1])
         dtype = X.dtype
 
-        QA, QB, KA, KB, VA, VB = \
-            QA.to(dtype), QB.to(dtype), KA.to(dtype), KB.to(dtype), VA.to(dtype), VB.to(dtype)
+        QA, QB, KA, KB, VA, VB = (
+            QA.to(dtype),
+            QB.to(dtype),
+            KA.to(dtype),
+            KB.to(dtype),
+            VA.to(dtype),
+            VB.to(dtype),
+        )
 
-        QA, QB, KA, KB, VA, VB = \
-            QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
+        QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
 
         ### Weight projection LoRA weights
         # See our blogpost for more details.
@@ -354,13 +498,25 @@ def backward(ctx, dQ, dK, dV):
         # QW, QW_quant, QA, QB, QS,
         # KW, KW_quant, KA, KB, KS,
         # VW, VW_quant, VA, VB, VS,
-        return dX.view(batch, seq_len, hd), \
-            None, None, d_QA.t(), d_QB.t(), None, \
-            None, None, d_KA.t(), d_KB.t(), None, \
-            None, None, d_VA.t(), d_VB.t(), None, \
+        return (
+            dX.view(batch, seq_len, hd),
+            None,
+            None,
+            d_QA.t(),
+            d_QB.t(),
+            None,
+            None,
+            None,
+            d_KA.t(),
+            d_KB.t(),
             None,
-    pass
-pass
+            None,
+            None,
+            d_VA.t(),
+            d_VB.t(),
+            None,
+            None,
+        )
 
 
 def apply_lora_qkv(self, X, inplace = True):
@@ -368,14 +524,26 @@ def apply_lora_qkv(self, X, inplace = True):
     QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
     KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
     VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
-    Q, K, V = LoRA_QKV.apply(X,
-        QW, QW_quant, QA, QB, QS,
-        KW, KW_quant, KA, KB, KS,
-        VW, VW_quant, VA, VB, VS,
+    Q, K, V = LoRA_QKV.apply(
+        X,
+        QW,
+        QW_quant,
+        QA,
+        QB,
+        QS,
+        KW,
+        KW_quant,
+        KA,
+        KB,
+        KS,
+        VW,
+        VW_quant,
+        VA,
+        VB,
+        VS,
         inplace,
     )
     return Q, K, V
-pass
 
 
 class LoRA_W(torch.autograd.Function):
@@ -405,26 +573,29 @@ class LoRA_W(torch.autograd.Function):
     dC/dAv =       X.T @ D(Wv) @ B.T
     dC/dBv = A.T @ X.T @ D(Wv)
     """
+
     @staticmethod
     @torch_amp_custom_fwd
-    def forward(ctx, X : torch.Tensor,
-                W, W_quant, A, B, S):
+    def forward(ctx, X: torch.Tensor, W, W_quant, A, B, S):
         dtype = X.dtype
         XW = matmul_lora(X, W, W_quant, A, B, S)
-        ctx.custom_saved_tensors = (W, W_quant, S,)
+        ctx.custom_saved_tensors = (
+            W,
+            W_quant,
+            S,
+        )
         ctx.save_for_backward(A, B, X)
         return XW
-    pass
 
     @staticmethod
     @torch_amp_custom_bwd
-    def backward(ctx, dY : torch.Tensor):
+    def backward(ctx, dY: torch.Tensor):
         W, W_quant, S = ctx.custom_saved_tensors
         A, B, X = ctx.saved_tensors
 
         batch, seq_len, hd = X.shape
-        dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
-        X  = X .reshape(-1, X .shape[-1]) # Must be reshape
+        dY = dY.reshape(-1, dY.shape[-1])  # Must be reshape
+        X = X.reshape(-1, X.shape[-1])  # Must be reshape
         dtype = X.dtype
 
         A, B = A.to(dtype), B.to(dtype)
@@ -451,10 +622,7 @@ def backward(ctx, dY : torch.Tensor):
         dX.addmm_(dY @ B.t(), A.t(), alpha = S)
 
         # W, W_quant, A, B, S
-        return dX.view(batch, seq_len, hd), \
-            None, None, d_A.t(), d_B.t(), None
-    pass
-pass
+        return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
 
 
 def apply_lora_o(self, X):
@@ -462,10 +630,11 @@ def apply_lora_o(self, X):
     OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
     O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
     return O
-pass
 
 
 IDENTITY_DROPOUT = torch.nn.Identity
+
+
 @torch._disable_dynamo
 def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     raise NotImplementedError(
@@ -479,24 +648,28 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             self.unmerge()
         result = self.base_layer(x, *args, **kwargs)
     elif adapter_names is not None:
-        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
+        result = self._mixed_batch_forward(
+            x, *args, adapter_names = adapter_names, **kwargs
+        )
     elif self.merged:
         result = self.base_layer(x, *args, **kwargs)
     else:
         # Fastpath
         if len(self.active_adapters) == 1:
             active_adapter = self.active_adapters[0]
-            if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)
+            if active_adapter not in self.lora_A.keys():
+                return self.base_layer(x, *args, **kwargs)
 
             dropout = self.lora_dropout[active_adapter]
-            if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
+            if (
+                isinstance(dropout, IDENTITY_DROPOUT)
+                and not self.use_dora[active_adapter]
+            ):
                 lora_A = self.lora_A[active_adapter].weight
                 lora_B = self.lora_B[active_adapter].weight
                 scaling = self.scaling[active_adapter]
                 W = self.base_layer.weight
                 return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
-            pass
-        pass
 
         result = self.base_layer(x, *args, **kwargs)
         # As per Tim Dettmers, for 4bit, we need to defensively clone here.
@@ -530,14 +703,13 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
 
                 result = result + self.lora_magnitude_vector[active_adapter](
                     x,
-                    lora_A=lora_A,
-                    lora_B=lora_B,
-                    scaling=scaling,
-                    base_layer=self.get_base_layer(),
-                    base_result=base_result,
+                    lora_A = lora_A,
+                    lora_B = lora_B,
+                    scaling = scaling,
+                    base_layer = self.get_base_layer(),
+                    base_result = base_result,
                 )
             if requires_conversion:
                 result = result.to(expected_dtype)
 
     return result
-pass
diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py
index 6f8239422..b94ff56de 100644
--- a/unsloth/kernels/flex_attention.py
+++ b/unsloth/kernels/flex_attention.py
@@ -18,11 +18,11 @@
 import os
 
 torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : True,
-    "shape_padding"     : True,
-    "trace.enabled"     : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
-    "triton.cudagraphs" : False,
+    "epilogue_fusion": True,
+    "max_autotune": True,
+    "shape_padding": True,
+    "trace.enabled": os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
+    "triton.cudagraphs": False,
 }
 
 # Flex Attention supported from torch 2.5 onwards only
@@ -31,23 +31,24 @@
         flex_attention as _flex_attention,
         create_block_mask as _create_block_mask,
     )
-    _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
+
+    _flex_attention = torch.compile(
+        _flex_attention, dynamic = True, options = torch_compile_options
+    )
     HAS_FLEX_ATTENTION = False
 except:
     HAS_FLEX_ATTENTION = False
-pass
 
 
 if not HAS_FLEX_ATTENTION:
-
     # Logit softcapping
     @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
     def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
-        n_heads    = self.config.num_attention_heads
-        head_dim   = self.head_dim
+        n_heads = self.config.num_attention_heads
+        head_dim = self.head_dim
         n_kv_heads = self.config.num_key_value_heads
-        n_groups   = self.num_key_value_groups
-        
+        n_groups = self.num_key_value_groups
+
         # Grouped query attention
         K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
         V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
@@ -61,18 +62,17 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
         s = self.config.query_pre_attn_scalar
         t = self.config.attn_logit_softcapping
 
-        Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
+        Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype)  # Follow Keras exactly
         A = torch.matmul(Q, K.transpose(2, 3))
-        A = t * torch.tanh(A / t) # Logit softcapping
+        A = t * torch.tanh(A / t)  # Logit softcapping
         A += causal_mask[:q_len, :q_len]
         # Much slower in torch compile!
         # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
         A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
         A = torch.matmul(A, V)
         A = A.transpose(1, 2).contiguous()
-        A = A.reshape(bsz, q_len, n_heads*head_dim)
+        A = A.reshape(bsz, q_len, n_heads * head_dim)
         return A
-    pass
 
     create_flex_attention_causal_mask = None
     create_flex_attention_sliding_window_mask = None
@@ -85,73 +85,78 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
     def generate_tanh_softcap(t):
         def tanh_softcap(x, b, h, q_idx, kv_idx):
             return t * torch.tanh(x / t)
+
         return tanh_softcap
-    pass
+
     def causal_masker(b, h, q_idx, kv_idx):
         return q_idx >= kv_idx
-    pass
 
     @functools.lru_cache
     def sliding_window_masker(size = 4096):
         def sliding_window(b, h, q_idx, kv_idx):
             causal_mask = q_idx >= kv_idx
-            window_mask = q_idx - kv_idx <= size 
+            window_mask = q_idx - kv_idx <= size
             return causal_mask & window_mask
+
         return sliding_window
-    pass
 
     @functools.lru_cache
     def create_block_mask(mask, n = 128):
         return _create_block_mask(
-            mask, 1, 1, n, n,
+            mask,
+            1,
+            1,
+            n,
+            n,
             BLOCK_SIZE = 128,
             _compile = True,
         )
-    pass
 
     def create_flex_attention_causal_mask(max_seq_length = 8192):
         causal_mask = create_block_mask(causal_masker, max_seq_length)
         return causal_mask
-    pass
 
-    def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
+    def create_flex_attention_sliding_window_mask(
+        max_seq_length = 8192, sliding_window = 4096
+    ):
         sliding_masker = sliding_window_masker(sliding_window)
         causal_mask = create_block_mask(sliding_masker, max_seq_length)
         return causal_mask
-    pass
 
     @functools.lru_cache
     def flex_attention(s, t):
         scale = 1.0 / math.sqrt(s)
         score_mod = generate_tanh_softcap(t)
         return functools.partial(
-            _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
+            _flex_attention,
+            score_mod = score_mod,
+            scale = scale,
+            enable_gqa = True,
         )
-    pass
-    
+
     def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
-        n_heads    = self.config.num_attention_heads
-        head_dim   = self.head_dim
+        n_heads = self.config.num_attention_heads
+        head_dim = self.head_dim
         s = self.config.query_pre_attn_scalar
         t = self.config.attn_logit_softcapping
         fx = flex_attention(s, t)
         A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
         A = A.transpose(1, 2).contiguous()
-        A = A.reshape(bsz, q_len, n_heads*head_dim)
+        A = A.reshape(bsz, q_len, n_heads * head_dim)
         return A
-    pass
-pass
 
 
 torch_matmul = torch.matmul
-torch_tanh   = torch.tanh
+torch_tanh = torch.tanh
 torch_nn_functional_softmax = torch.nn.functional.softmax
+
+
 def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
-    n_heads    = self.config.num_attention_heads
-    head_dim   = self.head_dim
+    n_heads = self.config.num_attention_heads
+    head_dim = self.head_dim
     n_kv_heads = self.config.num_key_value_heads
-    n_groups   = self.num_key_value_groups
-    
+    n_groups = self.num_key_value_groups
+
     # Grouped query attention
     K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
     V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
@@ -165,17 +170,18 @@ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len)
     s = self.config.query_pre_attn_scalar
     t = self.config.attn_logit_softcapping
 
-    Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
+    Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype)  # Follow Keras exactly
     A = torch_matmul(Q, K.transpose(2, 3))
 
     # Logit softcapping
-    A /= t; torch_tanh(A, out = A); A *= t;
+    A /= t
+    torch_tanh(A, out = A)
+    A *= t
     A += causal_mask[:q_len, :q_len]
     # Much slower in torch compile!
     # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
     A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
     A = torch_matmul(A, V)
     A = A.transpose(1, 2).contiguous()
-    A = A.reshape(bsz, q_len, n_heads*head_dim)
+    A = A.reshape(bsz, q_len, n_heads * head_dim)
     return A
-pass
diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py
index 0e893a360..0230da95b 100644
--- a/unsloth/kernels/fp8.py
+++ b/unsloth/kernels/fp8.py
@@ -19,25 +19,34 @@
 import math
 from unsloth_zoo.log import logger
 from unsloth_zoo.temporary_patches.common import torch_compile
+
 torch_matmul = torch.matmul
 
 try:
     from transformers.integrations.finegrained_fp8 import FP8Linear
 except:
     FP8Linear = None
-    logger.info("Unsloth: FP8 models need importing FP8Linear from `transformers.integrations.finegrained_fp8` but we don't see it.")
+    logger.info(
+        "Unsloth: FP8 models need importing FP8Linear from `transformers.integrations.finegrained_fp8` but we don't see it."
+    )
 
 try:
     from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear
 except:
     FbgemmFp8Linear = None
-    logger.info("Unsloth: FP8 models need importing FbgemmFP8Linear from `transformers.integrations.fbgemm_fp8` but we don't see it.")
+    logger.info(
+        "Unsloth: FP8 models need importing FbgemmFP8Linear from `transformers.integrations.fbgemm_fp8` but we don't see it."
+    )
 
 try:
-    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+        triton_quantize_fp8_block,
+    )
 except:
     triton_quantize_fp8_block = None
-    logger.info("Unsloth: Could not find fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm.triton_quantize_fp8_block")
+    logger.info(
+        "Unsloth: Could not find fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm.triton_quantize_fp8_block"
+    )
 
 try:
     from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
@@ -45,26 +54,29 @@
     )
 except:
     torchao_blockwise_gemm = None
-    logger.info("Unsloth: Could not find torchao.prototype.blockwise_fp8_inference.blockwise_quantization.blockwise_fp8_gemm")
+    logger.info(
+        "Unsloth: Could not find torchao.prototype.blockwise_fp8_inference.blockwise_quantization.blockwise_fp8_gemm"
+    )
 
 
 @triton.jit
 def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
-    pid_m = tl.program_id(axis=0)
-    pid_n = tl.program_id(axis=1)
+    pid_m = tl.program_id(axis = 0)
+    pid_n = tl.program_id(axis = 1)
     n = tl.cdiv(N, BLOCK_SIZE)
     offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     offs = offs_m[:, None] * N + offs_n[None, :]
     mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
-    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    x = tl.load(x_ptr + offs, mask = mask).to(tl.float32)
     s = tl.load(s_ptr + pid_m * n + pid_n)
     y = x * s
-    tl.store(y_ptr + offs, y, mask=mask)
-pass
+    tl.store(y_ptr + offs, y, mask = mask)
 
 
-def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype=torch.bfloat16) -> torch.Tensor:
+def weight_dequant_block(
+    x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype = torch.bfloat16
+) -> torch.Tensor:
     if not x.is_contiguous():
         x = x.contiguous()
     if not s.is_contiguous():
@@ -72,12 +84,15 @@ def weight_dequant_block(x: torch.Tensor, s: torch.Tensor, block_size: int = 128
     assert x.dim() == 2 and s.dim() == 2
     M, N = x.size()
     y = torch.empty_like(x, dtype = dtype)
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
-    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(N, meta["BLOCK_SIZE"]),
+    )
+    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)
     return y
-pass
 
-def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
+
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
     if s.shape[1] == 1:
         # this is row quantized weight, just simple multiplication suffices
         if x.shape[0] == s.shape[0]:
@@ -87,17 +102,17 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype=torch.bfloat16):
             y = x.t().to(dtype) * s.to(dtype)
             y = y.t()
         else:
-            raise ValueError(f'Incompatible shapes {x.shape=}, {s.shape=}')
+            raise ValueError(f"Incompatible shapes {x.shape = }, {s.shape = }")
         return y
     else:
         # this is block quantized weight
-        return weight_dequant_block(x, s, dtype=dtype)
-pass
+        return weight_dequant_block(x, s, dtype = dtype)
+
 
 # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
 @triton.jit
 def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
-    pid = tl.program_id(axis=0)
+    pid = tl.program_id(axis = 0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
     s = tl.max(tl.abs(x)) / 448.0
@@ -109,13 +124,15 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
-pass
 
-def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+
+def act_quant(
+    x: torch.Tensor, block_size: int = 128
+) -> tuple[torch.Tensor, torch.Tensor]:
     if not x.is_contiguous():
         x = x.contiguous()
     assert x.shape[-1] % block_size == 0
-    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    y = torch.empty_like(x, dtype = torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype = torch.float32)
 
     def grid(meta):
@@ -123,7 +140,7 @@ def grid(meta):
 
     act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size)
     return y, s
-pass
+
 
 # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
 @triton.jit
@@ -163,7 +180,7 @@ def _w8a8_block_fp8_matmul(
     store the result in output tensor `C`.
     """
 
-    pid = tl.program_id(axis=0)
+    pid = tl.program_id(axis = 0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -185,8 +202,8 @@ def _w8a8_block_fp8_matmul(
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        a = tl.load(a_ptrs, mask = offs_k[None, :] < K - k * BLOCK_SIZE_K, other = 0.0)
+        b = tl.load(b_ptrs, mask = offs_k[:, None] < K - k * BLOCK_SIZE_K, other = 0.0)
 
         k_start = k * BLOCK_SIZE_K
         offs_ks = k_start // group_k
@@ -209,7 +226,7 @@ def _w8a8_block_fp8_matmul(
     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
     tl.store(c_ptrs, c, mask = c_mask)
-pass
+
 
 def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
@@ -259,7 +276,9 @@ def w8a8_block_fp8_matmul_triton(
     BLOCK_SIZE_N = block_n
 
     def grid(META):
-        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        )
 
     _w8a8_block_fp8_matmul[grid](
         A,
@@ -288,7 +307,7 @@ def grid(META):
         GROUP_SIZE_M = 8,
     )
     return C
-pass
+
 
 def torchao_block_matmul(
     act_q: torch.Tensor,
@@ -303,31 +322,42 @@ def torchao_block_matmul(
         act_scale.contiguous(),
         weight_q.contiguous(),
         weight_scale.contiguous(),
-        block_size=block_size[1],
+        block_size = block_size[1],
     )
     return out.to(output_dtype)
-pass
+
 
 # This torchao FP8 matmul seems to be ~3x faster than the w8a8_block_fp8_matmul_triton. Though this is 15-30% slower than fbgemm implementation.
 # But this gives very comparable results when it comes to training loss, so we prefer using it when available.
-fp8_block_matmul = torchao_block_matmul if torchao_blockwise_gemm is not None else w8a8_block_fp8_matmul_triton
+fp8_block_matmul = (
+    torchao_block_matmul
+    if torchao_blockwise_gemm is not None
+    else w8a8_block_fp8_matmul_triton
+)
 
-class FP8BlockQuantLinear(torch.autograd.Function):
 
+class FP8BlockQuantLinear(torch.autograd.Function):
     @staticmethod
     def forward(ctx, X, weight, weight_scale):
         # block_size = getattr(weight, 'block_size', [128,128])
         m, n = weight.shape
         p, q = weight_scale.shape
-        block_size = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', None)
+        block_size = getattr(weight, "block_size", None) or getattr(
+            weight_scale, "block_size", None
+        )
         assert block_size is not None, "block_size is not set"
         if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:
-            if triton.cdiv(m, block_size[0]) == q and triton.cdiv(n, block_size[1]) == p:
+            if (
+                triton.cdiv(m, block_size[0]) == q
+                and triton.cdiv(n, block_size[1]) == p
+            ):
                 # weights are tranposed during backward pass for training :)
                 # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
                 weight_scale = weight_scale.T
             else:
-                raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}")
+                raise ValueError(
+                    f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}"
+                )
 
         if not weight.is_contiguous():
             weight = weight.contiguous()
@@ -339,7 +369,7 @@ def forward(ctx, X, weight, weight_scale):
             scale,
             weight_scale,
             block_size,
-            output_dtype=X.dtype,
+            output_dtype = X.dtype,
         )
         ctx.weight = weight
         ctx.weight_scale = weight_scale
@@ -353,17 +383,18 @@ def backward(ctx, grad_output):
         del W_deq
         return grad_X, None, None
 
+
 @torch_compile
 def fp8_block_quant_forward(X, weight, weight_scale):
     return FP8BlockQuantLinear.apply(X, weight, weight_scale)
 
 
 class FbgemmFp8Linear_matmul(torch.autograd.Function):
-
     @staticmethod
-    def forward(ctx, x, weight, weight_scale, bias=None):
-
-        if weight.shape[0] == weight_scale.shape[0] and (weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0):
+    def forward(ctx, x, weight, weight_scale, bias = None):
+        if weight.shape[0] == weight_scale.shape[0] and (
+            weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0
+        ):
             # Edit: The kernel seems to expect that the weight has dimensions divisible by 8. Otherwise it throws `RuntimeError: cutlass cannot implement`
             # One thing we can do is to pad the weight and weight scale to multiple of 8 and perform a F8F8BF16 operation.
             # I tried benchmarking that for speed but observed that dequantize+bf16 matmul is significantly faster than padding+f8f8bf16 matmul. So we'll go that route.
@@ -374,7 +405,8 @@ def forward(ctx, x, weight, weight_scale, bias=None):
             # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
             # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
             x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                x.view(-1, x.shape[-1]).contiguous(), scale_ub = getattr(weight, 'input_scale_ub', None)
+                x.view(-1, x.shape[-1]).contiguous(),
+                scale_ub = getattr(weight, "input_scale_ub", None),
             )
             # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
             # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
@@ -395,7 +427,10 @@ def forward(ctx, x, weight, weight_scale, bias=None):
             output = output.to(x.device, x.dtype)
             output = output.reshape(output_shape)
             del x_quantized, x_scale
-        elif (weight.shape[0] != weight_scale.shape[0] and weight.shape[1] == weight_scale.shape[0]) or (weight.shape[0] // 8 != 0 or weight.shape[1] // 8 != 0):
+        elif (
+            weight.shape[0] != weight_scale.shape[0]
+            and weight.shape[1] == weight_scale.shape[0]
+        ) or (weight.shape[0] // 8 != 0 or weight.shape[1] // 8 != 0):
             # Either the weight/scale is transposed or its shape is not divisible by 8. Both cases, dequantizing is the preferred way.
             # The transpose case is generally noticed in backward pass when we do dY@W instead of @W.T as we do for forward.
             # The shape case, I noticed to happen in MLP of Qwen 2.5 VL 7B where the gate proj is of shape (3420, 1280) and 3420/8=427.5
@@ -404,7 +439,9 @@ def forward(ctx, x, weight, weight_scale, bias=None):
             output = torch_matmul(x, W_deq)
             del W_deq
         else:
-            raise ValueError(f"Shapes are incompatible {weight.shape=}, {weight_scale.shape=}, {x.shape=}")
+            raise ValueError(
+                f"Shapes are incompatible {weight.shape = }, {weight_scale.shape = }, {x.shape = }"
+            )
 
         ctx.weight = weight
         ctx.weight_scale = weight_scale
@@ -417,19 +454,21 @@ def backward(ctx, grad_output):
         del W_deq
         return grad_X, None, None, None, None
 
+
 @torch_compile
-def fbgemm_fp8_linear(X, weight, weight_scale, bias=None):
+def fbgemm_fp8_linear(X, weight, weight_scale, bias = None):
     return FbgemmFp8Linear_matmul.apply(X, weight, weight_scale, bias)
 
 
 class FP8_torch_linear(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, X, weight, weight_scale, bias=None):
-
+    def forward(ctx, X, weight, weight_scale, bias = None):
         orig_shape = X.shape
         X = X.view(-1, X.shape[-1])
 
-        bs_n, bs_k = getattr(weight, 'block_size', None) or getattr(weight_scale, 'block_size', [128, 128])
+        bs_n, bs_k = getattr(weight, "block_size", None) or getattr(
+            weight_scale, "block_size", [128, 128]
+        )
         bs_m = bs_n
 
         m, n = weight.shape
@@ -441,14 +480,18 @@ def forward(ctx, X, weight, weight_scale, bias=None):
                 # We tranpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
                 weight_scale = weight_scale.T
             else:
-                raise ValueError(f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}")
+                raise ValueError(
+                    f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}"
+                )
 
         xq, xs = triton_quantize_fp8_block(X, bs_m, bs_n, None)
         ## TODO: Investigate and resolve the high divergence of this output from baseline
         # WARNING: This causes the outputs to diverge from expected when X has high values in it.
         # That results in the model producing gibberish, especially on longer sequences and training loss starting at high values like 8 instead of <1 ideally
         # Please refrain from using this till this issue is resolved. This exists here just for a future headstart.
-        output = torch.ops.fbgemm.f8f8bf16_blockwise(xq, weight.contiguous(), xs, weight_scale.contiguous(), bs_m, bs_n, bs_k)
+        output = torch.ops.fbgemm.f8f8bf16_blockwise(
+            xq, weight.contiguous(), xs, weight_scale.contiguous(), bs_m, bs_n, bs_k
+        )
         output = output + bias if bias is not None else output
 
         output = output.view(*orig_shape[:-1], -1)
@@ -468,13 +511,14 @@ def backward(ctx, grad_output):
         del W_deq
         return grad_X, None, None, None, None
 
+
 @torch_compile
-def fp8_torch_linear(X, weight, weight_scale, bias=None):
+def fp8_torch_linear(X, weight, weight_scale, bias = None):
     return FP8_torch_linear.apply(X, weight, weight_scale, bias)
 
 
 @torch_compile
-def fp8_linear(X, weight, weight_scale, bias=None):
+def fp8_linear(X, weight, weight_scale, bias = None):
     if weight_scale.ndim == 2 and weight_scale.shape[1] > 1:
         # This is block quantized FP8 matmul
         out = fp8_block_quant_forward(X, weight, weight_scale)
@@ -488,14 +532,17 @@ def fp8_linear(X, weight, weight_scale, bias=None):
     return out
 
 
-def module_forward_patch(forward_function, scale_attr='weight_scale'):
+def module_forward_patch(forward_function, scale_attr = "weight_scale"):
     def patched_forward(self, X):
         return forward_function(X, self.weight, getattr(self, scale_attr))
+
     return patched_forward
 
 
 # Patch the forward functions of the layers (for compiled models)
 if FbgemmFp8Linear is not None:
-    FbgemmFp8Linear.forward = module_forward_patch(fbgemm_fp8_linear, 'weight_scale')
+    FbgemmFp8Linear.forward = module_forward_patch(fbgemm_fp8_linear, "weight_scale")
 if FP8Linear is not None:
-    FP8Linear.forward = module_forward_patch(fp8_block_quant_forward, 'weight_scale_inv')
+    FP8Linear.forward = module_forward_patch(
+        fp8_block_quant_forward, "weight_scale_inv"
+    )
diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py
index 67f576df3..67c36dd8a 100644
--- a/unsloth/kernels/geglu.py
+++ b/unsloth/kernels/geglu.py
@@ -23,23 +23,28 @@
 
 
 @triton.jit
-def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
+def _exact_forward_kernel(
+    e,
+    g,
+    h,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
     block_idx = tl.program_id(0)
-    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
 
     # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
     # h = f * up
     e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)
 
     f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
-    f_row = f_row.to(g_row.dtype) # Exact copy from HF
+    f_row = f_row.to(g_row.dtype)  # Exact copy from HF
     h_row = f_row * g_row
 
     # Store h
     tl.store(h + offsets, h_row, mask = mask)
-pass
 
 
 def geglu_exact_forward_kernel(gate, up):
@@ -47,15 +52,26 @@ def geglu_exact_forward_kernel(gate, up):
     n_elements = gate.numel()
     device = gate.device
     out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
     with torch_gpu_device(device):
-        _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
+        _exact_forward_kernel[grid](
+            gate,
+            up,
+            out,
+            n_elements,
+            BLOCK_SIZE = 1024,
+        )
     return out
-pass
 
 
 @triton.jit
-def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
+def _exact_backward_kernel(
+    DW,
+    e,
+    g,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
     """
     f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
     h = f * up
@@ -67,74 +83,82 @@ def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
     f =        1/2 * (1 + erf(1/sqrt(2) * e)) * e
     """
     block_idx = tl.program_id(0)
-    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
 
-    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
-    e_row  = tl.load(e  + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row  = tl.load(g  + offsets, mask = mask, other = 0)#.to(tl.float32)
+    DW_row = tl.load(DW + offsets, mask = mask, other = 0)  # .to(tl.float32)
+    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)
 
     # Break e_row away for re-use
     # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
     f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
     f_row = f_partial_row * e_row
-    
+
     f_row = f_row.to(DW_row.dtype)
     # h = f * g
-    h_row  =  f_row * g_row
+    h_row = f_row * g_row
     # df = DW * f
     df_row = DW_row * f_row
     # dg = DW * g
     dg_row = DW_row * g_row
 
     # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
-    t = 0.3989422804014327 # 1/sqrt(2*pi)
+    t = 0.3989422804014327  # 1/sqrt(2*pi)
     df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
 
     de_row = dg_row.to(tl.float32) * df_de
     de_row = de_row.to(DW_row.dtype)
 
     # Store derivatives in buffers
-    tl.store(DW + offsets, h_row,  mask = mask) # h  = f * g
-    tl.store(e  + offsets, df_row, mask = mask) # df = DW * f
-    tl.store(g  + offsets, de_row, mask = mask) # de
-pass
+    tl.store(DW + offsets, h_row, mask = mask)  # h  = f * g
+    tl.store(e + offsets, df_row, mask = mask)  # df = DW * f
+    tl.store(g + offsets, de_row, mask = mask)  # de
 
 
 def geglu_exact_backward_kernel(DW, e, g):
     batch_seq_len, hd = e.shape
     n_elements = e.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
     with torch_gpu_device(e.device):
-        _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+        _exact_backward_kernel[grid](
+            DW,
+            e,
+            g,
+            n_elements,
+            BLOCK_SIZE = 1024,
+        )
     return DW, e, g
-pass
 
 
 @triton.jit
-def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
+def _approx_forward_kernel(
+    e,
+    g,
+    h,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
     block_idx = tl.program_id(0)
-    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
 
     # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
     # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
     # h = f * up
-    s = 0.7978845608028654 # math.sqrt(2 / math.pi)
-    
+    s = 0.7978845608028654  # math.sqrt(2 / math.pi)
+
     e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)
 
-    f_row = 0.5 * e_row * (
-        triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
-        + 1.0
+    f_row = (
+        0.5 * e_row * (triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0)
     )
-    f_row = f_row.to(g_row.dtype) # Exact copy from HF
+    f_row = f_row.to(g_row.dtype)  # Exact copy from HF
     h_row = f_row * g_row
 
     # Store h
     tl.store(h + offsets, h_row, mask = mask)
-pass
 
 
 def geglu_approx_forward_kernel(gate, up):
@@ -142,15 +166,26 @@ def geglu_approx_forward_kernel(gate, up):
     n_elements = gate.numel()
     device = gate.device
     out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
     with torch_gpu_device(device):
-        _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
+        _approx_forward_kernel[grid](
+            gate,
+            up,
+            out,
+            n_elements,
+            BLOCK_SIZE = 1024,
+        )
     return out
-pass
 
 
 @triton.jit
-def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
+def _approx_backward_kernel(
+    DW,
+    e,
+    g,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
     """
     f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
     h = f * up
@@ -166,28 +201,28 @@ def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
     See https://www.desmos.com/calculator/nqprfoni6x
     """
     block_idx = tl.program_id(0)
-    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
 
-    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
-    e_row  = tl.load(e  + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row  = tl.load(g  + offsets, mask = mask, other = 0)#.to(tl.float32)
+    DW_row = tl.load(DW + offsets, mask = mask, other = 0)  # .to(tl.float32)
+    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)
 
     # See https://www.desmos.com/calculator/nqprfoni6x
-    s = 0.7978845608028654 # math.sqrt(2 / math.pi)
-    a = s * e_row # a = sqrt(2 / pi) * x
-    b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
+    s = 0.7978845608028654  # math.sqrt(2 / math.pi)
+    a = s * e_row  # a = sqrt(2 / pi) * x
+    b = a * 0.044715 * e_row * e_row  # b = a * 0.044715 * x^2
     T = 1.0 + triton_tanh(a + b)
     T2 = 0.5 * T
     # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
-    Q2 = -T2 * (T - 2.0) * (a + 3.0 * b) 
-    df_de = T2 + Q2 # 1/2 * (T + Q)
+    Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
+    df_de = T2 + Q2  # 1/2 * (T + Q)
 
     # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
     f_row = T2 * e_row
     f_row = f_row.to(DW_row.dtype)
     # h = f * g
-    h_row  =  f_row * g_row
+    h_row = f_row * g_row
     # df = DW * f
     df_row = DW_row * f_row
     # dg = DW * g
@@ -197,17 +232,21 @@ def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
     de_row = de_row.to(DW_row.dtype)
 
     # Store derivatives in buffers
-    tl.store(DW + offsets, h_row,  mask = mask) # h  = f * g
-    tl.store(e  + offsets, df_row, mask = mask) # df = DW * f
-    tl.store(g  + offsets, de_row, mask = mask) # de
-pass
+    tl.store(DW + offsets, h_row, mask = mask)  # h  = f * g
+    tl.store(e + offsets, df_row, mask = mask)  # df = DW * f
+    tl.store(g + offsets, de_row, mask = mask)  # de
 
 
 def geglu_approx_backward_kernel(DW, e, g):
     batch_seq_len, hd = e.shape
     n_elements = e.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
     with torch_gpu_device(e.device):
-        _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+        _approx_backward_kernel[grid](
+            DW,
+            e,
+            g,
+            n_elements,
+            BLOCK_SIZE = 1024,
+        )
     return DW, e, g
-pass
diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py
index f01c4ffb0..5e2e3af2f 100644
--- a/unsloth/kernels/layernorm.py
+++ b/unsloth/kernels/layernorm.py
@@ -24,23 +24,25 @@
 
 @triton.jit
 def layernorm_forward(
-    Y, Y_row_stride,
-    X, X_row_stride,
+    Y,
+    Y_row_stride,
+    X,
+    X_row_stride,
     W,
     b,
     r,
     mu,
-    n_cols : tl.constexpr,
-    eps : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr
+    n_cols: tl.constexpr,
+    eps: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     row_idx = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y  += row_idx * Y_row_stride
-    X  += row_idx * X_row_stride
-    r  += row_idx
+    Y += row_idx * Y_row_stride
+    X += row_idx * X_row_stride
+    r += row_idx
     mu += row_idx
 
     # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
@@ -49,29 +51,30 @@ def layernorm_forward(
     W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
     b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
 
-    mean_X  = tl.sum(X_row,   axis = 0) / n_cols
+    mean_X = tl.sum(X_row, axis = 0) / n_cols
     # (X[0] - mean) == -mean so we need to mask it out
     XX = tl.where(mask, X_row - mean_X, 0)
     row_var = tl.sum(XX * XX, axis = 0) / n_cols
     inv_var = tl.math.rsqrt(row_var + eps)
-    tl.store (r, inv_var)
-    tl.store (mu, mean_X)
+    tl.store(r, inv_var)
+    tl.store(mu, mean_X)
     output = (XX * inv_var) * W_row + b_row
     tl.store(Y + col_offsets, output, mask = mask)
-pass
 
 
 @triton.jit
 def layernorm_backward(
-    dY, dY_row_stride,
-    X,   X_row_stride,
+    dY,
+    dY_row_stride,
+    X,
+    X_row_stride,
     W,
     b,
     r,
     mu,
-    n_cols : tl.constexpr,
-    eps : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr
+    n_cols: tl.constexpr,
+    eps: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     row_idx = tl.program_id(0)
@@ -79,25 +82,28 @@ def layernorm_backward(
     mask = col_offsets < n_cols
 
     dY += row_idx * dY_row_stride
-    X  += row_idx *  X_row_stride
-    r  += row_idx
+    X += row_idx * X_row_stride
+    r += row_idx
     mu += row_idx
 
     # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
     # are in float32!
     dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
-    X_row  = tl.load(X  + col_offsets, mask = mask, other = 0).to(tl.float32)
-    W_row  = tl.load(W  + col_offsets, mask = mask, other = 0).to(tl.float32)
-    b_row  = tl.load(b  + col_offsets, mask = mask, other = 0).to(tl.float32)
+    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+    b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
 
-    inv_var = tl.load(r) .to(tl.float32)
-    mean    = tl.load(mu).to(tl.float32)
-    normed  = (X_row - mean) * inv_var
+    inv_var = tl.load(r).to(tl.float32)
+    mean = tl.load(mu).to(tl.float32)
+    normed = (X_row - mean) * inv_var
     dY_W = dY_row * W_row
-    dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
+    dX_row = (
+        dY_W
+        - tl.sum(dY_W, axis = 0) / n_cols
+        - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
+    )
     dX_row = dX_row * inv_var
     tl.store(dY + col_offsets, dX_row, mask = mask)
-pass
 
 
 class Fast_Layernorm(torch.autograd.Function):
@@ -109,28 +115,30 @@ def forward(ctx, X, W, b, eps):
         n_rows, n_cols = X.shape
         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
         device = X.device
-        Y  = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
-        r  = torch.empty(n_rows, dtype = torch.float32, device = device)
+        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
+        r = torch.empty(n_rows, dtype = torch.float32, device = device)
         mu = torch.empty(n_rows, dtype = torch.float32, device = device)
 
         with torch_gpu_device(device):
             layernorm_forward[(n_rows,)](
-                Y, Y.stride(0),
-                X, X.stride(0),
+                Y,
+                Y.stride(0),
+                X,
+                X.stride(0),
                 W,
                 b,
                 r,
                 mu,
-                n_cols, eps,
+                n_cols,
+                eps,
                 BLOCK_SIZE = BLOCK_SIZE,
-                num_warps  = num_warps,
+                num_warps = num_warps,
             )
         ctx.eps = eps
         ctx.BLOCK_SIZE = BLOCK_SIZE
-        ctx.num_warps  = num_warps
+        ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, b, r, mu)
         return Y.view(*shape)
-    pass
 
     @staticmethod
     def backward(ctx, dY):
@@ -142,40 +150,46 @@ def backward(ctx, dY):
 
         with torch_gpu_device(dY.device):
             layernorm_backward[(n_rows,)](
-                dY, dY.stride(0),
-                X,  X .stride(0),
+                dY,
+                dY.stride(0),
+                X,
+                X.stride(0),
                 W,
                 b,
                 r,
                 mu,
-                n_cols, ctx.eps,
+                n_cols,
+                ctx.eps,
                 BLOCK_SIZE = ctx.BLOCK_SIZE,
-                num_warps  = ctx.num_warps,
+                num_warps = ctx.num_warps,
             )
         dX = dY.view(*shape)
         return dX, None, None, None, None
-    pass
-pass
 
 
 def fast_layernorm(layernorm, X):
-    assert(layernorm.elementwise_affine is True)
-    W    = layernorm.weight
+    assert layernorm.elementwise_affine is True
+    W = layernorm.weight
     bias = layernorm.bias
-    eps = layernorm.variance_epsilon if \
-        hasattr(layernorm, "variance_epsilon") \
+    eps = (
+        layernorm.variance_epsilon
+        if hasattr(layernorm, "variance_epsilon")
         else layernorm.eps
+    )
     out = Fast_Layernorm.apply(X, W, bias, eps)
     return out
-pass
-
 
 
 def test_layernorm(
-    dim = 1024, eps = 1e-5, dtype = torch.float16,
-    bsz = 21, random_state = 3407, seqlen = 3341,
+    dim = 1024,
+    eps = 1e-5,
+    dtype = torch.float16,
+    bsz = 21,
+    random_state = 3407,
+    seqlen = 3341,
 ):
     from torch.nn import LayerNorm
+
     layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
     torch.cuda.manual_seed(random_state)
     torch.manual_seed(random_state)
@@ -183,7 +197,7 @@ def test_layernorm(
     torch.nn.init.uniform_(layernorm.bias)
     X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
     XX = X.clone()
-    X .requires_grad_(True)
+    X.requires_grad_(True)
     XX.requires_grad_(True)
     Y = layernorm(X)
     YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
@@ -192,8 +206,7 @@ def test_layernorm(
     # from unsloth.kernels import fast_layernorm
     Y = fast_layernorm(layernorm, XX)
     Y.backward(YY)
-    assert(torch.dist(correct_grad, XX.grad).item() <= 0.1)
-pass
+    assert torch.dist(correct_grad, XX.grad).item() <= 0.1
 
 
 def testing_suite_layernorm():
@@ -210,9 +223,3 @@ def testing_suite_layernorm():
                             random_state = random_state,
                             seqlen = seqlen,
                         )
-                    pass
-                pass
-            pass
-        pass
-    pass
-pass
diff --git a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py b/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
index 2fe2afa1e..074cc5a56 100644
--- a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
+++ b/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
@@ -56,7 +56,7 @@ def run_benchmark_forward(
     hidden_size = config.hidden_size
 
     X = torch.randn(
-        bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
+        bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
     )
 
     # Forward
@@ -89,7 +89,7 @@ def run_benchmark_backward(
     config: AutoConfig,
     seqlen: int,
     dtype: torch.dtype,
-    bs=1,
+    bs = 1,
 ):
     torch.manual_seed(
         SEED
@@ -98,7 +98,7 @@ def run_benchmark_backward(
     hidden_size = config.hidden_size
 
     X = torch.randn(
-        bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
+        bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
     )
     X_test = X.detach().clone().requires_grad_(True)
 
@@ -114,14 +114,14 @@ def run_benchmark_backward(
 
     # Bench
     grad_output = torch.randn_like(output)
-    bench_backward_ref = lambda: output.backward(grad_output, retain_graph=True)  # noqa: E731
-    bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph=True)  # noqa: E731
+    bench_backward_ref = lambda: output.backward(grad_output, retain_graph = True)  # noqa: E731
+    bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph = True)  # noqa: E731
 
     ref_backward_time = do_bench(
-        bench_backward_ref, grad_to_none=[X, *ref_model.parameters()]
+        bench_backward_ref, grad_to_none = [X, *ref_model.parameters()]
     )
     fused_backward_time = do_bench(
-        bench_backward_fused, grad_to_none=[X_test, *tt_model.parameters()]
+        bench_backward_fused, grad_to_none = [X_test, *tt_model.parameters()]
     )
     print(
         f"Backward: ref {ref_backward_time:.4f}, fused {fused_backward_time:.4f}, speedup {ref_backward_time / fused_backward_time:.1f}x"
@@ -138,10 +138,10 @@ def setup_model(
     kernel_config_fwd,
     kernel_config_bwd_dW,
     kernel_config_bwd_dX,
-    dX_only=False,
-    dW_only=False,
-    overlap_router_shared=False,
-    device="cuda",
+    dX_only = False,
+    dW_only = False,
+    overlap_router_shared = False,
+    device = "cuda",
 ):
     if isinstance(config, Qwen3MoeConfig):
         ref_model = Qwen3MoeSparseMoeBlock(config).to(device, dtype)
@@ -149,29 +149,29 @@ def setup_model(
         # Triton kernel grouped gemm version of MoE Block -- this is what we're testing
         tt_model = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
             ref_model,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            autotune=autotune,
-            kernel_config_fwd=kernel_config_fwd,
-            kernel_config_bwd_dW=kernel_config_bwd_dW,
-            kernel_config_bwd_dX=kernel_config_bwd_dX,
-            dX_only=dX_only,
-            dW_only=dW_only,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            autotune = autotune,
+            kernel_config_fwd = kernel_config_fwd,
+            kernel_config_bwd_dW = kernel_config_bwd_dW,
+            kernel_config_bwd_dX = kernel_config_bwd_dX,
+            dX_only = dX_only,
+            dW_only = dW_only,
         ).to(device, dtype)
 
     elif isinstance(config, Llama4TextConfig):
         ref_model = Llama4TextMoe(config).to(device, dtype)
         tt_model = Llama4TritonTextMoe(
             config,
-            overlap_router_shared=overlap_router_shared,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            autotune=autotune,
-            kernel_config_fwd=kernel_config_fwd,
-            kernel_config_bwd_dW=kernel_config_bwd_dW,
-            kernel_config_bwd_dX=kernel_config_bwd_dX,
-            dX_only=dX_only,
-            dW_only=dW_only,
+            overlap_router_shared = overlap_router_shared,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            autotune = autotune,
+            kernel_config_fwd = kernel_config_fwd,
+            kernel_config_bwd_dW = kernel_config_bwd_dW,
+            kernel_config_bwd_dX = kernel_config_bwd_dX,
+            dX_only = dX_only,
+            dW_only = dW_only,
         ).to(device, dtype)
 
     else:
@@ -205,31 +205,31 @@ def run_benchmark(
 
     ref_model, tt_model = setup_model(
         model_config,
-        dtype=dtype,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        autotune=autotune,
-        kernel_config_fwd=kernel_config_fwd,
-        kernel_config_bwd_dW=kernel_config_bwd_dW,
-        kernel_config_bwd_dX=kernel_config_bwd_dX,
-        dX_only=dX_only,
-        dW_only=dW_only,
-        overlap_router_shared=overlap_router_shared,
+        dtype = dtype,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        autotune = autotune,
+        kernel_config_fwd = kernel_config_fwd,
+        kernel_config_bwd_dW = kernel_config_bwd_dW,
+        kernel_config_bwd_dX = kernel_config_bwd_dX,
+        dX_only = dX_only,
+        dW_only = dW_only,
+        overlap_router_shared = overlap_router_shared,
     )
 
     if mode == "forward":
         ref_time, fused_time = run_benchmark_forward(
             ref_model,
             tt_model,
-            config=model_config,
-            seqlen=seqlen,
-            dtype=dtype,
-            autotune=autotune,
-            kernel_config_fwd=kernel_config_fwd,
+            config = model_config,
+            seqlen = seqlen,
+            dtype = dtype,
+            autotune = autotune,
+            kernel_config_fwd = kernel_config_fwd,
         )
     else:
         ref_time, fused_time = run_benchmark_backward(
-            ref_model, tt_model, config=model_config, seqlen=seqlen, dtype=dtype
+            ref_model, tt_model, config = model_config, seqlen = seqlen, dtype = dtype
         )
 
     if autotune:
@@ -251,60 +251,60 @@ def run_benchmark(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--results_dir", type=str, default="benchmark_results")
-    parser.add_argument("--model", type=str, choices=["llama4", "qwen3"], required=True)
-    parser.add_argument("--seqlen", type=int, default=1024)
+    parser.add_argument("--results_dir", type = str, default = "benchmark_results")
+    parser.add_argument("--model", type = str, choices = ["llama4", "qwen3"], required = True)
+    parser.add_argument("--seqlen", type = int, default = 1024)
     parser.add_argument(
-        "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
+        "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
     )
-    parser.add_argument("--permute_x", action="store_true")
-    parser.add_argument("--permute_y", action="store_true")
-    parser.add_argument("--autotune", action="store_true")
-    parser.add_argument("--overlap_router_shared", action="store_true")
+    parser.add_argument("--permute_x", action = "store_true")
+    parser.add_argument("--permute_y", action = "store_true")
+    parser.add_argument("--autotune", action = "store_true")
+    parser.add_argument("--overlap_router_shared", action = "store_true")
     parser.add_argument(
         "--BLOCK_SIZE_M",
-        nargs=2,
-        type=int,
-        default=[DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
+        nargs = 2,
+        type = int,
+        default = [DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
     )
     parser.add_argument(
         "--BLOCK_SIZE_N",
-        nargs=2,
-        type=int,
-        default=[DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
+        nargs = 2,
+        type = int,
+        default = [DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
     )
     parser.add_argument(
         "--BLOCK_SIZE_K",
-        nargs=2,
-        type=int,
-        default=[DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
+        nargs = 2,
+        type = int,
+        default = [DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
     )
     parser.add_argument(
         "--num_warps",
-        nargs=2,
-        type=int,
-        default=[DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
+        nargs = 2,
+        type = int,
+        default = [DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
     )
     parser.add_argument(
         "--num_stages",
-        nargs=2,
-        type=int,
-        default=[DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
+        nargs = 2,
+        type = int,
+        default = [DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
     )
     parser.add_argument(
-        "--use_tma_load_w", action="store_true"
+        "--use_tma_load_w", action = "store_true"
     )  # No need to specify, will automatically parametrize these for each kernel config
     parser.add_argument(
-        "--use_tma_load_x", action="store_true"
+        "--use_tma_load_x", action = "store_true"
     )  # No need to specify, will automatically parametrize these for each kernel config
     parser.add_argument(
-        "--use_tma_load_dy", action="store_true"
+        "--use_tma_load_dy", action = "store_true"
     )  # No need to specify, will automatically parametrize these for each kernel config
     parser.add_argument(
         "--mode",
-        type=str,
-        choices=["forward", "backward", "dW", "dX"],
-        default="forward",
+        type = str,
+        choices = ["forward", "backward", "dW", "dX"],
+        default = "forward",
     )
     args = parser.parse_args()
     args.dtype = getattr(torch, args.dtype)
@@ -324,13 +324,13 @@ def run_benchmark(
         ref_time, fused_time = run_benchmark(
             args.mode,
             model_config,
-            seqlen=args.seqlen,
-            dtype=args.dtype,
-            permute_x=args.permute_x,
-            permute_y=args.permute_y,
-            autotune=args.autotune,
-            overlap_router_shared=args.overlap_router_shared,
-            results_dir=args.results_dir,
+            seqlen = args.seqlen,
+            dtype = args.dtype,
+            permute_x = args.permute_x,
+            permute_y = args.permute_y,
+            autotune = args.autotune,
+            overlap_router_shared = args.overlap_router_shared,
+            results_dir = args.results_dir,
         )
         end_time = time.time()
         print(f"Total time: {end_time - start_time:.4f} seconds")
@@ -343,13 +343,13 @@ def run_benchmark(
         kernel_configs = create_kernel_configs(args, args.permute_x, args.permute_y)
         print(f"Running {len(kernel_configs)} kernel configs")
         default_kernel_config_fwd = KernelConfigForward(
-            permute_x=args.permute_x, permute_y=args.permute_y
+            permute_x = args.permute_x, permute_y = args.permute_y
         )
         default_kernel_config_bwd_dW = KernelConfigBackward_dW(
-            permute_x=args.permute_x, permute_y=args.permute_y
+            permute_x = args.permute_x, permute_y = args.permute_y
         )
         default_kernel_config_bwd_dX = KernelConfigBackward_dX(
-            permute_x=args.permute_x, permute_y=args.permute_y
+            permute_x = args.permute_x, permute_y = args.permute_y
         )
         results = []
         for kernel_config in kernel_configs:
@@ -374,21 +374,21 @@ def run_benchmark(
             ref_time, fused_time = run_benchmark(
                 args.mode,
                 model_config,
-                seqlen=args.seqlen,
-                dtype=args.dtype,
-                permute_x=kernel_config.permute_x,
-                permute_y=kernel_config.permute_y,
-                autotune=False,
-                kernel_config_fwd=kernel_config_fwd,
-                kernel_config_bwd_dW=kernel_config_bwd_dW,
-                kernel_config_bwd_dX=kernel_config_bwd_dX,
+                seqlen = args.seqlen,
+                dtype = args.dtype,
+                permute_x = kernel_config.permute_x,
+                permute_y = kernel_config.permute_y,
+                autotune = False,
+                kernel_config_fwd = kernel_config_fwd,
+                kernel_config_bwd_dW = kernel_config_bwd_dW,
+                kernel_config_bwd_dX = kernel_config_bwd_dX,
             )
             results.append(
                 KernelResult(
-                    torch_time=ref_time,
-                    triton_time=fused_time,
-                    speedup=ref_time / fused_time,
-                    kernel_config=kernel_config,
+                    torch_time = ref_time,
+                    triton_time = fused_time,
+                    speedup = ref_time / fused_time,
+                    kernel_config = kernel_config,
                 )
             )
         df = post_process_results(
diff --git a/unsloth/kernels/moe/benchmark/utils.py b/unsloth/kernels/moe/benchmark/utils.py
index 3c4fd705d..21905d8df 100644
--- a/unsloth/kernels/moe/benchmark/utils.py
+++ b/unsloth/kernels/moe/benchmark/utils.py
@@ -44,7 +44,7 @@ def post_process_results(
     dtype: torch.dtype,
     autotune: bool,
 ):
-    df = KernelResult.to_dataframe(results, sort_by="speedup")
+    df = KernelResult.to_dataframe(results, sort_by = "speedup")
     df = create_merged_results(df, mode, seqlen, dtype, autotune)
     return df
 
@@ -63,16 +63,16 @@ def save_results(
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     print(f"Saving results to {save_path}")
-    df.to_csv(save_path, index=False)
+    df.to_csv(save_path, index = False)
 
 
 def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y: bool):
     block_m_range = power_of_two_range(args.BLOCK_SIZE_M[0], args.BLOCK_SIZE_M[1])
     block_n_range = power_of_two_range(args.BLOCK_SIZE_N[0], args.BLOCK_SIZE_N[1])
     block_k_range = power_of_two_range(args.BLOCK_SIZE_K[0], args.BLOCK_SIZE_K[1])
-    num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step=2)
+    num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step = 2)
     num_stages_range = multiples_of_range(
-        args.num_stages[0], args.num_stages[1], step=1
+        args.num_stages[0], args.num_stages[1], step = 1
     )
 
     mode = args.mode
@@ -96,39 +96,39 @@ def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y:
     ):
         if mode == "forward":
             kernel_config = KernelConfigForward(
-                BLOCK_SIZE_M=block_m,
-                BLOCK_SIZE_N=block_n,
-                BLOCK_SIZE_K=block_k,
-                num_warps=num_warps,
-                num_stages=num_stages,
-                use_tma_load_w=tma_load_a,
-                use_tma_load_x=tma_load_b,
-                permute_x=permute_x,
-                permute_y=permute_y,
+                BLOCK_SIZE_M = block_m,
+                BLOCK_SIZE_N = block_n,
+                BLOCK_SIZE_K = block_k,
+                num_warps = num_warps,
+                num_stages = num_stages,
+                use_tma_load_w = tma_load_a,
+                use_tma_load_x = tma_load_b,
+                permute_x = permute_x,
+                permute_y = permute_y,
             )
         elif mode == "dW":
             kernel_config = KernelConfigBackward_dW(
-                BLOCK_SIZE_M=block_m,
-                BLOCK_SIZE_N=block_n,
-                BLOCK_SIZE_K=block_k,
-                num_warps=num_warps,
-                num_stages=num_stages,
-                use_tma_load_dy=tma_load_a,
-                use_tma_load_x=tma_load_b,
-                permute_x=permute_x,
-                permute_y=permute_y,
+                BLOCK_SIZE_M = block_m,
+                BLOCK_SIZE_N = block_n,
+                BLOCK_SIZE_K = block_k,
+                num_warps = num_warps,
+                num_stages = num_stages,
+                use_tma_load_dy = tma_load_a,
+                use_tma_load_x = tma_load_b,
+                permute_x = permute_x,
+                permute_y = permute_y,
             )
         elif mode == "dX":
             kernel_config = KernelConfigBackward_dX(
-                BLOCK_SIZE_M=block_m,
-                BLOCK_SIZE_N=block_n,
-                BLOCK_SIZE_K=block_k,
-                num_warps=num_warps,
-                num_stages=num_stages,
-                use_tma_load_dy=tma_load_a,
-                use_tma_load_w=tma_load_b,
-                permute_x=permute_x,
-                permute_y=permute_y,
+                BLOCK_SIZE_M = block_m,
+                BLOCK_SIZE_N = block_n,
+                BLOCK_SIZE_K = block_k,
+                num_warps = num_warps,
+                num_stages = num_stages,
+                use_tma_load_dy = tma_load_a,
+                use_tma_load_w = tma_load_b,
+                permute_x = permute_x,
+                permute_y = permute_y,
             )
         else:
             raise ValueError(f"Invalid mode: {mode}")
@@ -161,7 +161,7 @@ def power_of_two_range(start, end):
     return [2**i for i in range(int(start), int(end) + 1)]
 
 
-def multiples_of_range(start, end, step=1):
+def multiples_of_range(start, end, step = 1):
     return list(range(start, end + step, step))
 
 
@@ -221,8 +221,8 @@ def postprocess_autotune_results(autotuner, mode, ref_time, fused_time, results_
         print(f"{mode} {key}: {value.all_kwargs()}")
     save_autotune_results(
         autotuner.cache,
-        mode=mode,
-        ref_time=ref_time,
-        fused_time=fused_time,
-        results_dir=results_dir,
+        mode = mode,
+        ref_time = ref_time,
+        fused_time = fused_time,
+        results_dir = results_dir,
     )
diff --git a/unsloth/kernels/moe/grouped_gemm/interface.py b/unsloth/kernels/moe/grouped_gemm/interface.py
index 5d1a8865e..d278d8964 100644
--- a/unsloth/kernels/moe/grouped_gemm/interface.py
+++ b/unsloth/kernels/moe/grouped_gemm/interface.py
@@ -60,7 +60,7 @@ def alloc_fn(size: int, alignment: int, stream):
                 or _per_stream_tensors[stream].numel() < size
             ):
                 _per_stream_tensors[stream] = torch.empty(
-                    size, device=device, dtype=torch.int8
+                    size, device = device, dtype = torch.int8
                 )
                 _per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
             return _per_stream_tensors[stream]
@@ -160,7 +160,7 @@ def grouped_gemm_forward(
     if use_tma or autotune:
 
         def alloc_fn(size: int, alignment: int, stream: int):
-            return torch.empty(size, device="cuda", dtype=torch.int8)
+            return torch.empty(size, device = "cuda", dtype = torch.int8)
 
         triton.set_allocator(alloc_fn)
 
@@ -168,22 +168,22 @@ def alloc_fn(size: int, alignment: int, stream: int):
     W = W.view(-1, W.shape[-1])
 
     if permute_x or permute_y:
-        assert gather_indices is not None, (
-            "gather_indices must be provided when permute_x or permute_y is True"
-        )
+        assert (
+            gather_indices is not None
+        ), "gather_indices must be provided when permute_x or permute_y is True"
         assert gather_indices.is_contiguous()
         assert gather_indices.device.type == "cuda"
         assert gather_indices.ndim == 1
         total_tokens = gather_indices.shape[0]
         num_tokens = total_tokens // topk
         if permute_x:
-            assert X.shape[0] == num_tokens, (
-                f"X.shape[0] ({X.shape[0]}) must match num_tokens ({num_tokens})"
-            )
+            assert (
+                X.shape[0] == num_tokens
+            ), f"X.shape[0] ({X.shape[0]}) must match num_tokens ({num_tokens})"
         else:
-            assert X.shape[0] == total_tokens, (
-                f"X.shape[0] ({X.shape[0]}) must match total_tokens ({total_tokens})"
-            )
+            assert (
+                X.shape[0] == total_tokens
+            ), f"X.shape[0] ({X.shape[0]}) must match total_tokens ({total_tokens})"
     else:
         total_tokens = X.shape[0]
         num_tokens = total_tokens // topk
@@ -211,7 +211,7 @@ def alloc_fn(size: int, alignment: int, stream: int):
                 f"DEBUG::GROUPED_GEMM {topk_weights.tolist()} {gather_indices.tolist()}"
             )
 
-    y = torch.empty((total_tokens, N), device=X.device, dtype=X.dtype)
+    y = torch.empty((total_tokens, N), device = X.device, dtype = X.dtype)
     if total_tokens == 0 or N == 0:
         return y
 
@@ -227,7 +227,7 @@ def grid(META):
 
     if debug:
         print(
-            f"DEBUG::GROUPED_GEMM {num_tokens=} {topk=} {num_experts=} {N=} {K=} {BLOCK_SIZE_M=} {BLOCK_SIZE_N=} {BLOCK_SIZE_K=} {permute_x=}"
+            f"DEBUG::GROUPED_GEMM {num_tokens = } {topk = } {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {permute_x = }"
         )
         print(
             f"DEBUG::GROUPED_GEMM {m_sizes.tolist()} {(gather_indices // topk).tolist()}"
@@ -328,12 +328,12 @@ def grouped_gemm_dX(
     use_tma_load_w: use TMA for loading weights.  If TMA supported, this should always be enabled as it is faster than global memory load.
     use_tma_store: use TMA for storing dX.  Incompatible with permute_x.  TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_x and use_tma_store.
     """
-    assert not fuse_mul_pre, (
-        "fuse_mul_pre should only be used for inference, not for training"
-    )
-    assert not fuse_mul_post, (
-        "fuse_mul_post should only be used for inference, not for training"
-    )
+    assert (
+        not fuse_mul_pre
+    ), "fuse_mul_pre should only be used for inference, not for training"
+    assert (
+        not fuse_mul_post
+    ), "fuse_mul_post should only be used for inference, not for training"
     assert dY.is_contiguous()
     assert W.is_contiguous()
     assert m_sizes.is_contiguous()
@@ -357,7 +357,7 @@ def grouped_gemm_dX(
 
         def alloc_fn(size: int, alignment: int, stream: int):
             # print(f"DEBUG::GROUPED_GEMM alloc_fn {size=} {alignment=} {stream=}")
-            return torch.empty(size, device="cuda", dtype=torch.int8)
+            return torch.empty(size, device = "cuda", dtype = torch.int8)
 
         triton.set_allocator(alloc_fn)
 
@@ -370,20 +370,20 @@ def alloc_fn(size: int, alignment: int, stream: int):
     N = N_total // num_experts
     assert N_grad == N, f"Grad_output N ({N_grad}) must match weight N ({N})"
 
-    assert M_total % topk == 0, (
-        f"M_total ({M_total}) must be divisible by topk ({topk})"
-    )
+    assert (
+        M_total % topk == 0
+    ), f"M_total ({M_total}) must be divisible by topk ({topk})"
     num_tokens = M_total // topk
 
     total_tokens = gather_indices.shape[0]
-    assert total_tokens == M_total, (
-        f"Total tokens ({total_tokens}) must match M_total ({M_total})"
-    )
+    assert (
+        total_tokens == M_total
+    ), f"Total tokens ({total_tokens}) must match M_total ({M_total})"
 
     # Note that the output shape is [NUM_TOKENS * TOPK, K] even when `permute_x` is True since we need to accumulate gradients across all experts chosen by the token.
     # This will be done in a post-processing step reduction step.
     output_shape = (total_tokens, K)
-    dX = torch.zeros(output_shape, device=dY.device, dtype=dY.dtype)
+    dX = torch.zeros(output_shape, device = dY.device, dtype = dY.dtype)
 
     NUM_SMS = torch.cuda.get_device_properties(
         "cuda"
@@ -399,7 +399,7 @@ def grid(META):
 
     if debug:
         print(
-            f"DEBUG::GROUPED_GEMM {num_tokens=} {topk=} {output_shape=} {num_experts=} {N=} {K=} {BLOCK_SIZE_M=} {BLOCK_SIZE_N=} {BLOCK_SIZE_K=} {NUM_SMS=}"
+            f"DEBUG::GROUPED_GEMM {num_tokens = } {topk = } {output_shape = } {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {NUM_SMS = }"
         )
         print(f"DEBUG::GROUPED_GEMM {m_sizes.tolist()}")
 
@@ -512,7 +512,7 @@ def grouped_gemm_dW(
     if use_tma or autotune:
 
         def alloc_fn(size: int, alignment: int, stream: int):
-            return torch.empty(size, device="cuda", dtype=torch.int8)
+            return torch.empty(size, device = "cuda", dtype = torch.int8)
 
         triton.set_allocator(alloc_fn)
 
@@ -538,7 +538,7 @@ def alloc_fn(size: int, alignment: int, stream: int):
 
     assert M_grad == total_tokens, f"dY M ({M_grad}) != total_tokens ({total_tokens})"
 
-    dW = torch.zeros((num_experts, N, K), device=X.device, dtype=X.dtype)
+    dW = torch.zeros((num_experts, N, K), device = X.device, dtype = X.dtype)
 
     if not autotune:
         BLOCK_SIZE_M = min(total_tokens, BLOCK_SIZE_M)
@@ -550,11 +550,11 @@ def grid(META):
 
     if debug:
         print(
-            f"DEBUG::GROUPED_GEMM_DW_TMA {num_experts=} {N=} {K=} {BLOCK_SIZE_M=} {BLOCK_SIZE_N=} {BLOCK_SIZE_K=} {NUM_SMS=}"
+            f"DEBUG::GROUPED_GEMM_DW_TMA {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {NUM_SMS = }"
         )
 
-        print(f"DEBUG::GROUPED_GEMM_DW_TMA {m_sizes.tolist()=}")
-        print(f"DEBUG::GROUPED_GEMM_DW_TMA {gather_indices.tolist()=}")
+        print(f"DEBUG::GROUPED_GEMM_DW_TMA {m_sizes.tolist() = }")
+        print(f"DEBUG::GROUPED_GEMM_DW_TMA {gather_indices.tolist() = }")
         m_start = 0
         for i in range(num_experts):
             expert_token_idx = gather_indices[m_start : m_start + m_sizes[i]]
@@ -663,17 +663,17 @@ def forward(
             fwd_config["use_tma_store"] = kernel_config_fwd.use_tma_store
 
         return grouped_gemm_forward(
-            X=X,
-            W=W,
-            topk=topk,
-            m_sizes=m_sizes,
-            gather_indices=gather_indices,
-            topk_weights=topk_weights,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            fuse_mul_post=fuse_mul_post,
+            X = X,
+            W = W,
+            topk = topk,
+            m_sizes = m_sizes,
+            gather_indices = gather_indices,
+            topk_weights = topk_weights,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            fuse_mul_post = fuse_mul_post,
             # Autotune -- this will override the manual kernel config if true
-            autotune=autotune,
+            autotune = autotune,
             # Manual kernel config
             **fwd_config,
         )
@@ -693,17 +693,17 @@ def backward(ctx, dY):
 
         if not autotune:
             if not dW_only:
-                assert kernel_config_bwd_dX is not None, (
-                    "kernel_config_bwd_dX must be provided if autotune is False"
-                )
+                assert (
+                    kernel_config_bwd_dX is not None
+                ), "kernel_config_bwd_dX must be provided if autotune is False"
             if not dX_only:
-                assert kernel_config_bwd_dW is not None, (
-                    "kernel_config_bwd_dW must be provided if autotune is False"
-                )
+                assert (
+                    kernel_config_bwd_dW is not None
+                ), "kernel_config_bwd_dW must be provided if autotune is False"
 
-        assert not fuse_mul_post, (
-            "fused_mul should only be used for inference, not for training"
-        )
+        assert (
+            not fuse_mul_post
+        ), "fused_mul should only be used for inference, not for training"
 
         if not dX_only:
             bwd_dW_config = {}
@@ -719,15 +719,15 @@ def backward(ctx, dY):
                 bwd_dW_config["num_stages"] = kernel_config_bwd_dW.num_stages
 
             dW = grouped_gemm_dW(
-                X=X,
-                dY=dY,
-                m_sizes=m_sizes,
-                gather_indices=gather_indices,
-                topk=topk,
-                permute_x=permute_x,
-                permute_y=permute_y,
+                X = X,
+                dY = dY,
+                m_sizes = m_sizes,
+                gather_indices = gather_indices,
+                topk = topk,
+                permute_x = permute_x,
+                permute_y = permute_y,
                 # Autotune -- this will override the manual kernel config if true
-                autotune=autotune,
+                autotune = autotune,
                 # Manual kernel config
                 **bwd_dW_config,
             )
@@ -747,21 +747,21 @@ def backward(ctx, dY):
                 bwd_dX_config["num_stages"] = kernel_config_bwd_dX.num_stages
 
             dX = grouped_gemm_dX(
-                dY=dY,
-                W=W,
-                m_sizes=m_sizes,
-                gather_indices=gather_indices,
-                topk=topk,
-                permute_x=permute_x,
-                permute_y=permute_y,
+                dY = dY,
+                W = W,
+                m_sizes = m_sizes,
+                gather_indices = gather_indices,
+                topk = topk,
+                permute_x = permute_x,
+                permute_y = permute_y,
                 # Autotune -- this will override the manual kernel config if true
-                autotune=autotune,
+                autotune = autotune,
                 # Manual kernel config
                 **bwd_dX_config,
             )
 
             if topk > 1 and permute_x:
-                dX = dX.view(X.shape[0], topk, -1).sum(dim=1)
+                dX = dX.view(X.shape[0], topk, -1).sum(dim = 1)
         else:
             dX = None
 
@@ -799,21 +799,21 @@ def check_valid_config_fwd(
     is_second_gemm = not is_first_gemm
 
     assert not (permute_x and permute_y), "Cannot permute both X and Y"
-    assert not (is_second_gemm and permute_x), (
-        "Cannot permute X for the second grouped GEMM"
-    )
-    assert not (is_first_gemm and permute_y), (
-        "Cannot permute Y for the first grouped GEMM"
-    )
-    assert not (fuse_mul_post and is_first_gemm), (
-        "Cannot fuse mul for the first grouped GEMM"
-    )
-    assert not (use_tma_load_x and permute_x), (
-        "Cannot use TMA load and permute X unless on sm100+ (Blackwell+)"
-    )
-    assert not (use_tma_store and permute_y and is_second_gemm), (
-        "Cannot use TMA store and permute Y for the second grouped GEMM unless on sm100+ (Blackwell+)"
-    )
+    assert not (
+        is_second_gemm and permute_x
+    ), "Cannot permute X for the second grouped GEMM"
+    assert not (
+        is_first_gemm and permute_y
+    ), "Cannot permute Y for the first grouped GEMM"
+    assert not (
+        fuse_mul_post and is_first_gemm
+    ), "Cannot fuse mul for the first grouped GEMM"
+    assert not (
+        use_tma_load_x and permute_x
+    ), "Cannot use TMA load and permute X unless on sm100+ (Blackwell+)"
+    assert not (
+        use_tma_store and permute_y and is_second_gemm
+    ), "Cannot use TMA store and permute Y for the second grouped GEMM unless on sm100+ (Blackwell+)"
 
 
 def check_valid_config_bwd_dW(
@@ -866,8 +866,8 @@ def grouped_gemm(
     gather_indices: torch.Tensor = None,
     permute_x: bool = False,
     permute_y: bool = False,
-    topk_weights=None,
-    fuse_mul_post=False,
+    topk_weights = None,
+    fuse_mul_post = False,
     kernel_config_fwd: KernelConfigForward = None,
     kernel_config_bwd_dX: KernelConfigBackward_dX = None,
     kernel_config_bwd_dW: KernelConfigBackward_dW = None,
@@ -901,49 +901,49 @@ def grouped_gemm(
 
     """
     if not autotune:
-        assert kernel_config_fwd is not None, (
-            "kernel_config_fwd must be provided if autotune is False"
-        )
+        assert (
+            kernel_config_fwd is not None
+        ), "kernel_config_fwd must be provided if autotune is False"
 
         check_valid_config_fwd(
             permute_x,
             permute_y,
-            use_tma_load_x=kernel_config_fwd.use_tma_load_x,
-            use_tma_load_w=kernel_config_fwd.use_tma_load_w,
-            use_tma_store=kernel_config_fwd.use_tma_store,
-            fuse_mul_post=fuse_mul_post,
-            is_first_gemm=is_first_gemm,
+            use_tma_load_x = kernel_config_fwd.use_tma_load_x,
+            use_tma_load_w = kernel_config_fwd.use_tma_load_w,
+            use_tma_store = kernel_config_fwd.use_tma_store,
+            fuse_mul_post = fuse_mul_post,
+            is_first_gemm = is_first_gemm,
         )
         if kernel_config_bwd_dW is not None and not dX_only:
             check_valid_config_bwd_dW(
                 permute_x,
                 permute_y,
-                use_tma_load_dY=kernel_config_bwd_dW.use_tma_load_dy,
-                use_tma_load_x=kernel_config_bwd_dW.use_tma_load_x,
-                use_tma_store=kernel_config_bwd_dW.use_tma_store,
-                fuse_mul_post=fuse_mul_post,
-                is_first_gemm=is_first_gemm,
+                use_tma_load_dY = kernel_config_bwd_dW.use_tma_load_dy,
+                use_tma_load_x = kernel_config_bwd_dW.use_tma_load_x,
+                use_tma_store = kernel_config_bwd_dW.use_tma_store,
+                fuse_mul_post = fuse_mul_post,
+                is_first_gemm = is_first_gemm,
             )
         if kernel_config_bwd_dX is not None and not dW_only:
             check_valid_config_bwd_dX(
                 permute_x,
                 permute_y,
-                use_tma_load_dY=kernel_config_bwd_dX.use_tma_load_dy,
-                use_tma_load_w=kernel_config_bwd_dX.use_tma_load_w,
-                use_tma_store=kernel_config_bwd_dX.use_tma_store,
-                fuse_mul_post=fuse_mul_post,
-                is_first_gemm=is_first_gemm,
+                use_tma_load_dY = kernel_config_bwd_dX.use_tma_load_dy,
+                use_tma_load_w = kernel_config_bwd_dX.use_tma_load_w,
+                use_tma_store = kernel_config_bwd_dX.use_tma_store,
+                fuse_mul_post = fuse_mul_post,
+                is_first_gemm = is_first_gemm,
             )
 
     if permute_x or permute_y:
-        assert gather_indices is not None, (
-            "gather_indices is required when either permute_x or permute_y is True"
-        )
+        assert (
+            gather_indices is not None
+        ), "gather_indices is required when either permute_x or permute_y is True"
 
     if fuse_mul_post:
-        assert topk_weights is not None, (
-            "topk_weights is required when fuse_mul_post is True"
-        )
+        assert (
+            topk_weights is not None
+        ), "topk_weights is required when fuse_mul_post is True"
 
     X = X.view(-1, X.shape[-1])
     m_sizes = m_sizes.view(-1)
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
index d57105cee..a185b5fd3 100644
--- a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
+++ b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
@@ -37,15 +37,15 @@ def convert_args_to_list(args):
 
 
 def get_forward_configs(
-    BLOCK_M=DEFAULT_M_BLOCK_SIZES,
-    BLOCK_N=DEFAULT_N_BLOCK_SIZES,
-    BLOCK_K=DEFAULT_K_BLOCK_SIZES,
-    TMA_LOAD_X=True,
-    TMA_LOAD_W=True,
-    TMA_STORE=False,  # NOTE: TMA_STORE is disabled for now
-    num_warps=DEFAULT_NUM_WARPS,
-    num_stages=DEFAULT_NUM_STAGES,
-    num_ctas=DEFAULT_NUM_CTAS,
+    BLOCK_M = DEFAULT_M_BLOCK_SIZES,
+    BLOCK_N = DEFAULT_N_BLOCK_SIZES,
+    BLOCK_K = DEFAULT_K_BLOCK_SIZES,
+    TMA_LOAD_X = True,
+    TMA_LOAD_W = True,
+    TMA_STORE = False,  # NOTE: TMA_STORE is disabled for now
+    num_warps = DEFAULT_NUM_WARPS,
+    num_stages = DEFAULT_NUM_STAGES,
+    num_ctas = DEFAULT_NUM_CTAS,
 ):
     (
         BLOCK_M,
@@ -95,16 +95,16 @@ def get_forward_configs(
         kernel_configs.append(
             triton.Config(
                 dict(
-                    BLOCK_SIZE_M=block_m,
-                    BLOCK_SIZE_N=block_n,
-                    BLOCK_SIZE_K=block_k,
-                    USE_TMA_LOAD_X=tma_load_x,
-                    USE_TMA_LOAD_W=tma_load_w,
-                    USE_TMA_STORE=tma_store,
+                    BLOCK_SIZE_M = block_m,
+                    BLOCK_SIZE_N = block_n,
+                    BLOCK_SIZE_K = block_k,
+                    USE_TMA_LOAD_X = tma_load_x,
+                    USE_TMA_LOAD_W = tma_load_w,
+                    USE_TMA_STORE = tma_store,
                 ),
-                num_warps=w,
-                num_stages=s,
-                num_ctas=num_ctas,
+                num_warps = w,
+                num_stages = s,
+                num_ctas = num_ctas,
             )
         )
 
@@ -112,15 +112,15 @@ def get_forward_configs(
 
 
 def get_dX_kernel_configs(
-    BLOCK_M=DEFAULT_M_BLOCK_SIZES,
-    BLOCK_N=DEFAULT_N_BLOCK_SIZES,
-    BLOCK_K=DEFAULT_K_BLOCK_SIZES,
-    TMA_LOAD_dY=True,
-    TMA_LOAD_W=True,
-    TMA_STORE=False,  # NOTE: TMA_STORE is disabled for now
-    num_warps=DEFAULT_NUM_WARPS,
-    num_stages=DEFAULT_NUM_STAGES,
-    num_ctas=DEFAULT_NUM_CTAS,
+    BLOCK_M = DEFAULT_M_BLOCK_SIZES,
+    BLOCK_N = DEFAULT_N_BLOCK_SIZES,
+    BLOCK_K = DEFAULT_K_BLOCK_SIZES,
+    TMA_LOAD_dY = True,
+    TMA_LOAD_W = True,
+    TMA_STORE = False,  # NOTE: TMA_STORE is disabled for now
+    num_warps = DEFAULT_NUM_WARPS,
+    num_stages = DEFAULT_NUM_STAGES,
+    num_ctas = DEFAULT_NUM_CTAS,
 ):
     (
         BLOCK_M,
@@ -170,16 +170,16 @@ def get_dX_kernel_configs(
         kernel_configs.append(
             triton.Config(
                 dict(
-                    BLOCK_SIZE_M=block_m,
-                    BLOCK_SIZE_N=block_n,
-                    BLOCK_SIZE_K=block_k,
-                    USE_TMA_LOAD_dY=tma_load_dy,
-                    USE_TMA_LOAD_W=tma_load_w,
-                    USE_TMA_STORE=tma_store,
+                    BLOCK_SIZE_M = block_m,
+                    BLOCK_SIZE_N = block_n,
+                    BLOCK_SIZE_K = block_k,
+                    USE_TMA_LOAD_dY = tma_load_dy,
+                    USE_TMA_LOAD_W = tma_load_w,
+                    USE_TMA_STORE = tma_store,
                 ),
-                num_warps=w,
-                num_stages=s,
-                num_ctas=num_ctas,
+                num_warps = w,
+                num_stages = s,
+                num_ctas = num_ctas,
             )
         )
 
@@ -187,15 +187,15 @@ def get_dX_kernel_configs(
 
 
 def get_dW_kernel_configs(
-    BLOCK_M=DEFAULT_M_BLOCK_SIZES,
-    BLOCK_N=DEFAULT_N_BLOCK_SIZES,
-    BLOCK_K=DEFAULT_K_BLOCK_SIZES,
-    num_warps=DEFAULT_NUM_WARPS,
-    num_stages=DEFAULT_NUM_STAGES,
-    num_ctas=DEFAULT_NUM_CTAS,
-    TMA_LOAD_dY=True,
-    TMA_LOAD_X=True,
-    TMA_STORE=False,
+    BLOCK_M = DEFAULT_M_BLOCK_SIZES,
+    BLOCK_N = DEFAULT_N_BLOCK_SIZES,
+    BLOCK_K = DEFAULT_K_BLOCK_SIZES,
+    num_warps = DEFAULT_NUM_WARPS,
+    num_stages = DEFAULT_NUM_STAGES,
+    num_ctas = DEFAULT_NUM_CTAS,
+    TMA_LOAD_dY = True,
+    TMA_LOAD_X = True,
+    TMA_STORE = False,
 ):
     (
         BLOCK_M,
@@ -245,16 +245,16 @@ def get_dW_kernel_configs(
         kernel_configs.append(
             triton.Config(
                 dict(
-                    BLOCK_SIZE_M=block_m,
-                    BLOCK_SIZE_N=block_n,
-                    BLOCK_SIZE_K=block_k,
-                    USE_TMA_LOAD_dY=tma_load_dy,
-                    USE_TMA_LOAD_X=tma_load_x,
-                    USE_TMA_STORE=tma_store,
+                    BLOCK_SIZE_M = block_m,
+                    BLOCK_SIZE_N = block_n,
+                    BLOCK_SIZE_K = block_k,
+                    USE_TMA_LOAD_dY = tma_load_dy,
+                    USE_TMA_LOAD_X = tma_load_x,
+                    USE_TMA_STORE = tma_store,
                 ),
-                num_warps=w,
-                num_stages=s,
-                num_ctas=num_ctas,
+                num_warps = w,
+                num_stages = s,
+                num_ctas = num_ctas,
             )
         )
 
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py b/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
index a05fb4d5d..d8bdcb57e 100644
--- a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
+++ b/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
@@ -84,18 +84,18 @@ def _grouped_gemm_dX_kernel(
     if USE_TMA_LOAD_dY:
         dY_desc = tl._experimental_make_tensor_descriptor(
             dY_ptr,
-            shape=[TOTAL_TOKENS, N],
-            strides=[N, 1],
-            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+            shape = [TOTAL_TOKENS, N],
+            strides = [N, 1],
+            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
         )
 
     if USE_TMA_LOAD_W:
         expert_stride = N * K
         w_desc = tl._experimental_make_tensor_descriptor(
             w_ptr,
-            shape=[NUM_EXPERTS, N, K],
-            strides=[expert_stride, K, 1],
-            block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
+            shape = [NUM_EXPERTS, N, K],
+            strides = [expert_stride, K, 1],
+            block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
         )
 
     m_end = 0
@@ -104,7 +104,7 @@ def _grouped_gemm_dX_kernel(
     n_block_range = tl.arange(0, BLOCK_SIZE_N)
     k_block_range = tl.arange(0, BLOCK_SIZE_K)
 
-    for expert_idx in range(NUM_EXPERTS, flatten=FLATTEN):
+    for expert_idx in range(NUM_EXPERTS, flatten = FLATTEN):
         m_start = m_end
         m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
         m_end = m_start + m_size
@@ -125,9 +125,9 @@ def _grouped_gemm_dX_kernel(
                 )
                 dX_desc = tl._experimental_make_tensor_descriptor(
                     dX_ptr,
-                    shape=[m_end, K],
-                    strides=[K, 1],
-                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+                    shape = [m_end, K],
+                    strides = [K, 1],
+                    block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
                 )
 
             # Lower bound and upper bound are defined relative to the total tiles processed so far
@@ -152,7 +152,7 @@ def _grouped_gemm_dX_kernel(
                     )
                     expert_token_idx = tl.load(
                         gather_indices_ptr + indices_to_gather,
-                        mask=indices_to_gather < TOTAL_TOKENS,
+                        mask = indices_to_gather < TOTAL_TOKENS,
                     )
                     expert_token_offsets = expert_token_idx[:, None]
 
@@ -210,13 +210,13 @@ def _grouped_gemm_dX_kernel(
                 # col_mask = offs_bk[None, :] < K
                 store_mask = row_mask  # & col_mask
 
-                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
+                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype = tl.float32)
 
                 # GEMM main loop
                 for n_offset in range(0, N, BLOCK_SIZE_N):
                     # dY block [M, N]
                     if not USE_TMA_LOAD_dY:
-                        dY = tl.load(dY_ptrs, mask=row_mask)
+                        dY = tl.load(dY_ptrs, mask = row_mask)
                     else:
                         dY = dY_desc.load(
                             [m_start + tile_m_idx * BLOCK_SIZE_M, n_offset]
@@ -253,7 +253,7 @@ def _grouped_gemm_dX_kernel(
                     tl.store(
                         dX_ptr + store_idx + offs_bk[None, :],
                         dX,
-                        mask=store_mask,
+                        mask = store_mask,
                     )
 
                 # Move to the next tile within this expert group
@@ -264,9 +264,9 @@ def _grouped_gemm_dX_kernel(
 
 
 _autotuned_grouped_gemm_dX_kernel = triton.autotune(
-    configs=get_dX_kernel_configs(),
-    prune_configs_by={"early_config_prune": prune_dX_configs},
-    key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
+    configs = get_dX_kernel_configs(),
+    prune_configs_by = {"early_config_prune": prune_dX_configs},
+    key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
 )(_grouped_gemm_dX_kernel)
 
 """
@@ -324,17 +324,17 @@ def _grouped_gemm_dW_kernel(
     if USE_TMA_LOAD_dY and not TMA_LOAD_BOTH:
         dY_desc = tl._experimental_make_tensor_descriptor(
             dY_ptr,
-            shape=[TOTAL_TOKENS, N],
-            strides=[N, 1],
-            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+            shape = [TOTAL_TOKENS, N],
+            strides = [N, 1],
+            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
         )
 
     if USE_TMA_LOAD_X and not TMA_LOAD_BOTH:
         x_desc = tl._experimental_make_tensor_descriptor(
             x_ptr,
-            shape=[TOTAL_TOKENS, K],
-            strides=[K, 1],
-            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+            shape = [TOTAL_TOKENS, K],
+            strides = [K, 1],
+            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
         )
     # Output tiles per expert, since each expert weight matrix is [N, K]
     num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
@@ -351,9 +351,9 @@ def _grouped_gemm_dW_kernel(
         tl.static_assert(K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K")
         dW_desc = tl._experimental_make_tensor_descriptor(
             dW_ptr,
-            shape=[NUM_EXPERTS, N, K],
-            strides=[N * K, K, 1],
-            block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
+            shape = [NUM_EXPERTS, N, K],
+            strides = [N * K, K, 1],
+            block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
         )
 
     for tile_idx in range(
@@ -377,7 +377,7 @@ def _grouped_gemm_dW_kernel(
         m_end = 0
         for expert_idx in range(NUM_EXPERTS):
             # We need to instantiate a fresh accumulator for each expert
-            accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=acc_dtype)
+            accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype = acc_dtype)
 
             m_start = m_end
             # Need to figure out why this cast is needed, otherwise compiler complains about mismatching types
@@ -392,16 +392,16 @@ def _grouped_gemm_dW_kernel(
                 if TMA_LOAD_BOTH:
                     dY_desc = tl._experimental_make_tensor_descriptor(
                         dY_ptr,
-                        shape=[m_end, N],
-                        strides=[N, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                        shape = [m_end, N],
+                        strides = [N, 1],
+                        block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
                     )
 
                     x_desc = tl._experimental_make_tensor_descriptor(
                         x_ptr,
-                        shape=[m_end, K],
-                        strides=[K, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+                        shape = [m_end, K],
+                        strides = [K, 1],
+                        block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
                     )
 
                 for tile_m_idx in range(0, m_size, BLOCK_SIZE_M):
@@ -425,7 +425,7 @@ def _grouped_gemm_dW_kernel(
                             # indices_to_gather = m_start + gather_offsets
                             expert_token_idx = tl.load(
                                 gather_indices_ptr + indices_to_gather,
-                                mask=indices_to_gather < TOTAL_TOKENS,
+                                mask = indices_to_gather < TOTAL_TOKENS,
                             )
                             expert_token_offsets = expert_token_idx[:, None]
 
@@ -461,7 +461,7 @@ def _grouped_gemm_dW_kernel(
                                 x_ptr
                                 + x_row_load_idx
                                 + (k_offset + block_range_k)[None, :],
-                                mask=mk_mask,
+                                mask = mk_mask,
                             )
 
                         if USE_TMA_LOAD_dY:
@@ -471,7 +471,7 @@ def _grouped_gemm_dW_kernel(
                                 dY_ptr
                                 + dY_row_load_idx
                                 + (n_offset + block_range_n)[None, :],
-                                mask=mn_mask,
+                                mask = mn_mask,
                             )
 
                         accumulator += tl.dot(
@@ -491,12 +491,12 @@ def _grouped_gemm_dW_kernel(
                         + store_row_offs[:, None] * K
                         + (k_offset + block_range_k)[None, :],
                         y,
-                        mask=nk_mask,
+                        mask = nk_mask,
                     )
 
 
 _autotuned_grouped_gemm_dW_kernel = triton.autotune(
-    configs=get_dW_kernel_configs(),
-    prune_configs_by={"early_config_prune": prune_kernel_configs_backward_dW},
-    key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
+    configs = get_dW_kernel_configs(),
+    prune_configs_by = {"early_config_prune": prune_kernel_configs_backward_dW},
+    key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
 )(_grouped_gemm_dW_kernel)
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/forward.py b/unsloth/kernels/moe/grouped_gemm/kernels/forward.py
index 98f665029..f84694e91 100644
--- a/unsloth/kernels/moe/grouped_gemm/kernels/forward.py
+++ b/unsloth/kernels/moe/grouped_gemm/kernels/forward.py
@@ -68,25 +68,25 @@ def _grouped_gemm_forward_kernel(
     if USE_TMA_LOAD_X:
         x_desc = tl._experimental_make_tensor_descriptor(
             x_ptr,
-            shape=[TOTAL_TOKENS, K],
-            strides=[K, 1],
-            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+            shape = [TOTAL_TOKENS, K],
+            strides = [K, 1],
+            block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
         )
 
     if USE_TMA_LOAD_W:
         expert_stride = N * K
         w_desc = tl._experimental_make_tensor_descriptor(
             w_ptr,
-            shape=[NUM_EXPERTS, N, K],
-            strides=[expert_stride, K, 1],
-            block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
+            shape = [NUM_EXPERTS, N, K],
+            strides = [expert_stride, K, 1],
+            block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
         )
 
     m_end = 0
     processed_tiles = 0
     m_block_range = tl.arange(0, BLOCK_SIZE_M)
 
-    for expert_idx in tl.range(NUM_EXPERTS, flatten=FLATTEN):
+    for expert_idx in tl.range(NUM_EXPERTS, flatten = FLATTEN):
         m_start = m_end
         m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
         m_end = m_start + m_size
@@ -102,9 +102,9 @@ def _grouped_gemm_forward_kernel(
             if USE_TMA_STORE:
                 y_desc = tl._experimental_make_tensor_descriptor(
                     y_ptr,  # + m_start * N,
-                    shape=[m_end, N],
-                    strides=[N, 1],
-                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                    shape = [m_end, N],
+                    strides = [N, 1],
+                    block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
                 )
 
             # Process tiles for this expert
@@ -127,7 +127,7 @@ def _grouped_gemm_forward_kernel(
                     )
                     expert_token_idx = tl.load(
                         gather_indices_ptr + indices_to_gather,
-                        mask=indices_to_gather < TOTAL_TOKENS,
+                        mask = indices_to_gather < TOTAL_TOKENS,
                     )
                     expert_token_offsets = expert_token_idx[:, None]
 
@@ -178,7 +178,7 @@ def _grouped_gemm_forward_kernel(
                 if SHOULD_FUSE_MUL:
                     topk_load_idx = expert_token_offsets
 
-                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
+                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = acc_dtype)
 
                 offs_k = tl.arange(0, BLOCK_SIZE_K)
 
@@ -194,19 +194,19 @@ def _grouped_gemm_forward_kernel(
 
                 for k_offset in range(0, K, BLOCK_SIZE_K):
                     if not USE_TMA_LOAD_X:
-                        x = tl.load(x_ptrs, mask=row_mask)
+                        x = tl.load(x_ptrs, mask = row_mask)
                     else:
                         x = x_desc.load([m_start + off_am, k_offset])
 
                     if FUSE_MUL_PRE:
                         # Check for correct broadcasting
                         topk_weights = tl.load(
-                            topk_weights_ptr + topk_load_idx, mask=row_mask
+                            topk_weights_ptr + topk_load_idx, mask = row_mask
                         )
                         x *= topk_weights.to(x.dtype)
 
                     if not USE_TMA_LOAD_W:
-                        w = tl.load(w_ptrs, mask=offs_bn[:, None] < N)
+                        w = tl.load(w_ptrs, mask = offs_bn[:, None] < N)
                     else:
                         w = w_desc.load(
                             [expert_idx, tile_n_idx * BLOCK_SIZE_N, k_offset]
@@ -228,7 +228,7 @@ def _grouped_gemm_forward_kernel(
                 if FUSE_MUL_POST:
                     # Check for correct broadcasting
                     topk_weights = tl.load(
-                        topk_weights_ptr + topk_load_idx, mask=row_mask
+                        topk_weights_ptr + topk_load_idx, mask = row_mask
                     )
                     y *= topk_weights.to(output_dtype)
 
@@ -243,7 +243,7 @@ def _grouped_gemm_forward_kernel(
                     tl.store(
                         y_ptr + store_idx + offs_bn[None, :],
                         y,
-                        mask=store_mask,
+                        mask = store_mask,
                     )
                 tidx += NUM_SMS
 
@@ -251,9 +251,9 @@ def _grouped_gemm_forward_kernel(
 
 
 _autotuned_grouped_gemm_forward_kernel = triton.autotune(
-    configs=get_forward_configs(),
-    prune_configs_by={"early_config_prune": prune_kernel_configs_fwd},
-    key=[
+    configs = get_forward_configs(),
+    prune_configs_by = {"early_config_prune": prune_kernel_configs_fwd},
+    key = [
         "NUM_EXPERTS",
         "NUM_TOKENS",
         "N",
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/tuning.py b/unsloth/kernels/moe/grouped_gemm/kernels/tuning.py
index 3632e811e..1f641478b 100644
--- a/unsloth/kernels/moe/grouped_gemm/kernels/tuning.py
+++ b/unsloth/kernels/moe/grouped_gemm/kernels/tuning.py
@@ -109,9 +109,9 @@ class KernelResult:
     def to_dict(self):
         return OrderedDict(
             **asdict(self.kernel_config),
-            torch_time=self.torch_time,
-            triton_time=self.triton_time,
-            speedup=self.speedup,
+            torch_time = self.torch_time,
+            triton_time = self.triton_time,
+            speedup = self.speedup,
         )
 
     @staticmethod
@@ -119,7 +119,7 @@ def to_dataframe(
         results: list["KernelResult"], sort_by: str = "speedup", ascending: bool = False
     ):
         df = pd.DataFrame([result.to_dict() for result in results])
-        df = df.sort_values(by=sort_by, ascending=ascending)
+        df = df.sort_values(by = sort_by, ascending = ascending)
         return df
 
     @staticmethod
@@ -130,7 +130,7 @@ def to_csv(
         filename: str = "results.csv",
     ):
         df = KernelResult.to_dataframe(results, sort_by, ascending)
-        df.to_csv(filename, index=False)
+        df.to_csv(filename, index = False)
 
     @staticmethod
     def print_table(
@@ -140,17 +140,17 @@ def print_table(
         num_results: int = 10,
     ):
         df = KernelResult.to_dataframe(results, sort_by, ascending)
-        print(df.head(num_results).to_string(index=False))
+        print(df.head(num_results).to_string(index = False))
 
 
 def get_kernel_configs(
-    BLOCK_M=DEFAULT_M_BLOCK_SIZES,
-    BLOCK_N=DEFAULT_N_BLOCK_SIZES,
-    BLOCK_K=DEFAULT_K_BLOCK_SIZES,
-    num_warps=DEFAULT_NUM_WARPS,
-    num_stages=DEFAULT_NUM_STAGES,
-    use_tma_loads=BOOLS,
-    fuse_permute=BOOLS,
+    BLOCK_M = DEFAULT_M_BLOCK_SIZES,
+    BLOCK_N = DEFAULT_N_BLOCK_SIZES,
+    BLOCK_K = DEFAULT_K_BLOCK_SIZES,
+    num_warps = DEFAULT_NUM_WARPS,
+    num_stages = DEFAULT_NUM_STAGES,
+    use_tma_loads = BOOLS,
+    fuse_permute = BOOLS,
 ):
     kernel_configs_fwd = []
     kernel_configs_backward_dW = []
@@ -160,44 +160,44 @@ def get_kernel_configs(
     ):
         kernel_configs_fwd.append(
             KernelConfigForward(
-                BLOCK_SIZE_M=block_m,
-                BLOCK_SIZE_N=block_n,
-                BLOCK_SIZE_K=block_k,
-                num_warps=w,
-                num_stages=s,
-                use_tma_load_x=use_tma_load,
-                use_tma_load_w=use_tma_load,
-                use_tma_store=False,
-                permute_x=permute,
-                permute_y=permute,
+                BLOCK_SIZE_M = block_m,
+                BLOCK_SIZE_N = block_n,
+                BLOCK_SIZE_K = block_k,
+                num_warps = w,
+                num_stages = s,
+                use_tma_load_x = use_tma_load,
+                use_tma_load_w = use_tma_load,
+                use_tma_store = False,
+                permute_x = permute,
+                permute_y = permute,
             )
         )
         kernel_configs_backward_dW.append(
             KernelConfigBackward_dW(
-                BLOCK_SIZE_M=block_m,
-                BLOCK_SIZE_N=block_n,
-                BLOCK_SIZE_K=block_k,
-                num_warps=w,
-                num_stages=s,
-                use_tma_load_dy=use_tma_load,
-                use_tma_load_x=use_tma_load,
-                use_tma_store=False,
-                permute_x=permute,
-                permute_y=permute,
+                BLOCK_SIZE_M = block_m,
+                BLOCK_SIZE_N = block_n,
+                BLOCK_SIZE_K = block_k,
+                num_warps = w,
+                num_stages = s,
+                use_tma_load_dy = use_tma_load,
+                use_tma_load_x = use_tma_load,
+                use_tma_store = False,
+                permute_x = permute,
+                permute_y = permute,
             )
         )
         kernel_configs_backward_dX.append(
             KernelConfigBackward_dX(
-                BLOCK_SIZE_M=block_m,
-                BLOCK_SIZE_N=block_n,
-                BLOCK_SIZE_K=block_k,
-                num_warps=w,
-                num_stages=s,
-                use_tma_load_dy=use_tma_load,
-                use_tma_load_w=use_tma_load,
-                use_tma_store=False,
-                permute_x=permute,
-                permute_y=permute,
+                BLOCK_SIZE_M = block_m,
+                BLOCK_SIZE_N = block_n,
+                BLOCK_SIZE_K = block_k,
+                num_warps = w,
+                num_stages = s,
+                use_tma_load_dy = use_tma_load,
+                use_tma_load_w = use_tma_load,
+                use_tma_store = False,
+                permute_x = permute,
+                permute_y = permute,
             )
         )
 
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
index 3d1afac9e..4010c77ce 100644
--- a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
+++ b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
@@ -51,9 +51,9 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
     def __init__(
         self,
         config: Llama4TextConfig,
-        overlap_router_shared=False,
-        verbose=False,
-        debug=False,
+        overlap_router_shared = False,
+        verbose = False,
+        debug = False,
     ):
         super().__init__(config)
         self.overlap_router_shared = overlap_router_shared
@@ -62,9 +62,9 @@ def __init__(
 
         # Permute in-place expert weights
         E, K, N = self.num_experts, self.hidden_dim, self.experts.expert_dim
-        assert self.experts.gate_up_proj.shape == torch.Size([E, K, 2 * N]), (
-            f"{self.experts.gate_up_proj.shape} != {[E, K, 2 * N]}"
-        )
+        assert self.experts.gate_up_proj.shape == torch.Size(
+            [E, K, 2 * N]
+        ), f"{self.experts.gate_up_proj.shape} != {[E, K, 2 * N]}"
         permuted_shape = [E, 2 * N, K]
         permuted_stride = [2 * N * K, K, 1]
         if verbose:
@@ -79,9 +79,9 @@ def __init__(
                 f"{self.experts.gate_up_proj.shape}:{self.experts.gate_up_proj.stride()}"
             )
 
-        assert self.experts.down_proj.shape == torch.Size([E, N, K]), (
-            f"{self.experts.down_proj.shape} != {[E, N, K]}"
-        )
+        assert self.experts.down_proj.shape == torch.Size(
+            [E, N, K]
+        ), f"{self.experts.down_proj.shape} != {[E, N, K]}"
         permuted_shape = [E, K, N]
         permuted_stride = [K * N, N, 1]
         if verbose:
@@ -110,9 +110,9 @@ def copy_weights(self, other: Llama4TextMoe):
             if any(n in name for n in self.EXPERT_WEIGHT_NAMES):
                 param_to_copy = param_to_copy.permute(0, 2, 1)
 
-            assert param.shape == param_to_copy.shape, (
-                f"{param.shape} != {param_to_copy.shape}"
-            )
+            assert (
+                param.shape == param_to_copy.shape
+            ), f"{param.shape} != {param_to_copy.shape}"
             param.copy_(param_to_copy)
 
         return self
@@ -136,7 +136,7 @@ def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         router_logits = self.router(hidden_states)
         routing_weights, selected_experts = torch.topk(
-            router_logits, self.top_k, dim=-1
+            router_logits, self.top_k, dim = -1
         )
 
         routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
@@ -167,9 +167,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits, routing_weights, selected_experts = self.run_router(
             hidden_states
         )
-        assert routing_weights.shape == (num_tokens, self.top_k), (
-            f"{routing_weights.shape} != {(num_tokens, self.top_k)}"
-        )
+        assert routing_weights.shape == (
+            num_tokens,
+            self.top_k,
+        ), f"{routing_weights.shape} != {(num_tokens, self.top_k)}"
 
         if self.overlap_router_shared:
             with torch.cuda.stream(self.shared_expert_stream):
@@ -194,7 +195,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
 
         if self.top_k > 1:
-            hidden_states = hidden_states.sum(dim=1)
+            hidden_states = hidden_states.sum(dim = 1)
         hidden_states_after_weight_merge = hidden_states.view(-1, hidden_dim)
 
         # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
@@ -211,7 +212,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # Start expert computation
         first_gemm = torch_grouped_gemm(
-            X=hidden_states, W=self.experts.gate_up_proj, m_sizes=token_counts_by_expert
+            X = hidden_states, W = self.experts.gate_up_proj, m_sizes = token_counts_by_expert
         )
         assert first_gemm.shape == (total_tokens, 2 * self.experts.expert_dim)
 
@@ -220,7 +221,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # See comment above
         second_gemm = torch_grouped_gemm(
-            X=intermediate, W=self.experts.down_proj, m_sizes=token_counts_by_expert
+            X = intermediate, W = self.experts.down_proj, m_sizes = token_counts_by_expert
         )
         assert second_gemm.shape == (total_tokens, hidden_dim)
 
@@ -233,17 +234,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         result = (
             Llama4MoeResult(
-                token_counts_by_expert=token_counts_by_expert,
-                gather_indices=gather_indices,
-                topk_weights=routing_weights,
-                hidden_states_after_weight_merge=hidden_states_after_weight_merge,
-                first_gemm=first_gemm,
-                intermediate=intermediate,
-                second_gemm=second_gemm,
-                hidden_states_unpermute=hidden_states_unpermute,
-                shared_expert_out=shared_expert_out,
-                final_out=final_out,
-                router_logits=router_logits,
+                token_counts_by_expert = token_counts_by_expert,
+                gather_indices = gather_indices,
+                topk_weights = routing_weights,
+                hidden_states_after_weight_merge = hidden_states_after_weight_merge,
+                first_gemm = first_gemm,
+                intermediate = intermediate,
+                second_gemm = second_gemm,
+                hidden_states_unpermute = hidden_states_unpermute,
+                shared_expert_out = shared_expert_out,
+                final_out = final_out,
+                router_logits = router_logits,
             )
             if self.debug
             else (final_out, routing_weights)
@@ -256,7 +257,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
     def __init__(
         self,
         config: Llama4TextConfig,
-        overlap_router_shared=False,
+        overlap_router_shared = False,
         permute_x: bool = False,
         permute_y: bool = True,
         autotune: bool = True,
@@ -265,12 +266,10 @@ def __init__(
         kernel_config_bwd_dX: KernelConfigBackward_dX = None,
         dW_only: bool = False,
         dX_only: bool = False,
-        verbose=False,
+        verbose = False,
     ):
-        super().__init__(config, overlap_router_shared=overlap_router_shared)
-        assert not permute_x, (
-            "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights"
-        )
+        super().__init__(config, overlap_router_shared = overlap_router_shared)
+        assert not permute_x, "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights"
         self.permute_x = permute_x
         self.permute_y = permute_y
         self.autotune = autotune
@@ -296,9 +295,9 @@ def copy_weights(self, other: Llama4TextMoe):
             if any(n in name for n in self.EXPERT_WEIGHT_NAMES):
                 param_to_copy = param_to_copy.permute(0, 2, 1)
 
-            assert param.shape == param_to_copy.shape, (
-                f"{param.shape} != {param_to_copy.shape}"
-            )
+            assert (
+                param.shape == param_to_copy.shape
+            ), f"{param.shape} != {param_to_copy.shape}"
             param.copy_(param_to_copy)
 
         return self
@@ -322,7 +321,7 @@ def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         router_logits = self.router(hidden_states)
         routing_weights, selected_experts = torch.topk(
-            router_logits, self.top_k, dim=-1
+            router_logits, self.top_k, dim = -1
         )
 
         routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
@@ -353,9 +352,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits, routing_weights, selected_experts = self.run_router(
             hidden_states
         )
-        assert routing_weights.shape == (num_tokens, self.top_k), (
-            f"{routing_weights.shape} != {(num_tokens, self.top_k)}"
-        )
+        assert routing_weights.shape == (
+            num_tokens,
+            self.top_k,
+        ), f"{routing_weights.shape} != {(num_tokens, self.top_k)}"
 
         if self.overlap_router_shared:
             with torch.cuda.stream(self.shared_expert_stream):
@@ -380,7 +380,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
 
         if self.top_k > 1:
-            hidden_states = hidden_states.sum(dim=1)
+            hidden_states = hidden_states.sum(dim = 1)
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
@@ -395,37 +395,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # Start expert computation
         hidden_states = grouped_gemm(
-            X=hidden_states,
-            W=self.experts.gate_up_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=self.permute_x,
-            permute_y=False,  # output of first grouped gemm should never be permuted
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=True,
-            dW_only=self.dW_only,
-            dX_only=self.dX_only,
+            X = hidden_states,
+            W = self.experts.gate_up_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = self.permute_x,
+            permute_y = False,  # output of first grouped gemm should never be permuted
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = True,
+            dW_only = self.dW_only,
+            dX_only = self.dX_only,
         )
         hidden_states = self.act_and_mul(hidden_states)
         hidden_states = grouped_gemm(
-            X=hidden_states,
-            W=self.experts.down_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=False,
-            permute_y=self.permute_y,
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=False,
-            dW_only=self.dW_only,
-            dX_only=self.dX_only,
+            X = hidden_states,
+            W = self.experts.down_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = False,
+            permute_y = self.permute_y,
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = False,
+            dW_only = self.dW_only,
+            dX_only = self.dX_only,
         )
 
         # Post-processing
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py
index 0ca4391b2..ace6a7714 100644
--- a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py
+++ b/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py
@@ -78,8 +78,8 @@ def __init__(
         self.gate = torch.nn.Parameter(gate)
 
         # experts
-        self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad=True)
-        self.down_proj = torch.nn.Parameter(down_proj, requires_grad=True)
+        self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad = True)
+        self.down_proj = torch.nn.Parameter(down_proj, requires_grad = True)
         self.act_fn = ACT2FN[config.hidden_act]
 
     @staticmethod
@@ -90,17 +90,17 @@ def extract_hf_weights(moe_block: Qwen3MoeSparseMoeBlock):
         gate = moe_block.gate.weight.data
         gate_proj = torch.stack(
             [moe_block.experts[i].gate_proj.weight.data for i in range(num_experts)],
-            dim=0,
+            dim = 0,
         )
         up_proj = torch.stack(
             [moe_block.experts[i].up_proj.weight.data for i in range(num_experts)],
-            dim=0,
+            dim = 0,
         )
         down_proj = torch.stack(
             [moe_block.experts[i].down_proj.weight.data for i in range(num_experts)],
-            dim=0,
+            dim = 0,
         )
-        gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
+        gate_up_proj = torch.cat([gate_proj, up_proj], dim = 1)
         return gate, gate_up_proj, down_proj
 
     @classmethod
@@ -117,7 +117,7 @@ def check_weights(self, moe_block: Qwen3MoeSparseMoeBlock):
                         moe_block.experts[i].gate_proj.weight.data,
                         moe_block.experts[i].up_proj.weight.data,
                     ],
-                    dim=0,
+                    dim = 0,
                 )
             )
             assert self.down_proj[i].equal(moe_block.experts[i].down_proj.weight.data)
@@ -132,12 +132,12 @@ def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = torch.nn.functional.linear(hidden_states, self.gate)
 
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)
         routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
+            routing_weights, self.top_k, dim = -1
         )
         if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
@@ -177,13 +177,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # Start expert computation
         first_gemm = torch_grouped_gemm(
-            X=hidden_states, W=self.gate_up_proj, m_sizes=token_counts_by_expert
+            X = hidden_states, W = self.gate_up_proj, m_sizes = token_counts_by_expert
         )
         assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
         intermediate = self.act_and_mul(first_gemm)
         assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
         second_gemm = torch_grouped_gemm(
-            X=intermediate, W=self.down_proj, m_sizes=token_counts_by_expert
+            X = intermediate, W = self.down_proj, m_sizes = token_counts_by_expert
         )
         assert second_gemm.shape == (total_tokens, hidden_dim)
 
@@ -197,19 +197,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
             * routing_weights[..., None]
         )
-        hidden_states = hidden_states.sum(dim=1)
+        hidden_states = hidden_states.sum(dim = 1)
         assert hidden_states.shape == (num_tokens, hidden_dim)
 
         hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
         return GroupedGEMMResult(
-            token_counts_by_expert=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk_weights=routing_weights,
-            first_gemm=first_gemm,
-            intermediate=intermediate,
-            second_gemm=second_gemm,
-            hidden_states_unpermute=hidden_states_unpermute,
-            hidden_states=hidden_states,
+            token_counts_by_expert = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk_weights = routing_weights,
+            first_gemm = first_gemm,
+            intermediate = intermediate,
+            second_gemm = second_gemm,
+            hidden_states_unpermute = hidden_states_unpermute,
+            hidden_states = hidden_states,
         ), router_logits
 
 
@@ -267,14 +267,14 @@ def from_hf(
             gate,
             gate_up_proj,
             down_proj,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            autotune=autotune,
-            kernel_config_fwd=kernel_config_fwd,
-            kernel_config_bwd_dW=kernel_config_bwd_dW,
-            kernel_config_bwd_dX=kernel_config_bwd_dX,
-            dW_only=dW_only,
-            dX_only=dX_only,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            autotune = autotune,
+            kernel_config_fwd = kernel_config_fwd,
+            kernel_config_bwd_dW = kernel_config_bwd_dW,
+            kernel_config_bwd_dX = kernel_config_bwd_dX,
+            dW_only = dW_only,
+            dX_only = dX_only,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -299,37 +299,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states = permute(hidden_states, gather_indices, self.top_k)
         # Start expert computation
         hidden_states = grouped_gemm(
-            X=hidden_states,
-            W=self.gate_up_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=self.permute_x,
-            permute_y=False,  # output of first grouped gemm should never be permuted
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=True,
-            dW_only=self.dW_only,
-            dX_only=self.dX_only,
+            X = hidden_states,
+            W = self.gate_up_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = self.permute_x,
+            permute_y = False,  # output of first grouped gemm should never be permuted
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = True,
+            dW_only = self.dW_only,
+            dX_only = self.dX_only,
         )
         hidden_states = self.act_and_mul(hidden_states)
         hidden_states = grouped_gemm(
-            X=hidden_states,
-            W=self.down_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=False,
-            permute_y=self.permute_y,
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=False,
-            dW_only=self.dW_only,
-            dX_only=self.dX_only,
+            X = hidden_states,
+            W = self.down_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = False,
+            permute_y = self.permute_y,
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = False,
+            dW_only = self.dW_only,
+            dX_only = self.dX_only,
         )
 
         # Post-processing
@@ -342,7 +342,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states.view(num_tokens, self.top_k, hidden_dim)
             * routing_weights[..., None]
         )
-        hidden_states = hidden_states.sum(dim=1)
+        hidden_states = hidden_states.sum(dim = 1)
 
         hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
         return hidden_states, router_logits
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
index cbccf19cb..0d497f380 100644
--- a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
+++ b/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
@@ -80,14 +80,14 @@ def from_hf(
             gate,
             gate_up_proj,
             down_proj,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            autotune=autotune,
-            kernel_config_fwd=kernel_config_fwd,
-            kernel_config_bwd_dW=kernel_config_bwd_dW,
-            kernel_config_bwd_dX=kernel_config_bwd_dX,
-            dW_only=dW_only,
-            dX_only=dX_only,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            autotune = autotune,
+            kernel_config_fwd = kernel_config_fwd,
+            kernel_config_bwd_dW = kernel_config_bwd_dW,
+            kernel_config_bwd_dX = kernel_config_bwd_dX,
+            dW_only = dW_only,
+            dX_only = dX_only,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -112,37 +112,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states = permute(hidden_states, gather_indices, self.top_k)
         # Start expert computation
         hidden_states = grouped_gemm(
-            X=hidden_states,
-            W=self.gate_up_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=self.permute_x,
-            permute_y=False,  # output of first grouped gemm should never be permuted
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=True,
-            dW_only=self.dW_only,
-            dX_only=self.dX_only,
+            X = hidden_states,
+            W = self.gate_up_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = self.permute_x,
+            permute_y = False,  # output of first grouped gemm should never be permuted
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = True,
+            dW_only = self.dW_only,
+            dX_only = self.dX_only,
         )
         hidden_states = self.act_and_mul(hidden_states)
         hidden_states = grouped_gemm(
-            X=hidden_states,
-            W=self.down_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=False,
-            permute_y=self.permute_y,
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=False,
-            dW_only=self.dW_only,
-            dX_only=self.dX_only,
+            X = hidden_states,
+            W = self.down_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = False,
+            permute_y = self.permute_y,
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = False,
+            dW_only = self.dW_only,
+            dX_only = self.dX_only,
         )
 
         # Post-processing
@@ -155,7 +155,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states.view(num_tokens, self.top_k, hidden_dim)
             * routing_weights[..., None]
         )
-        hidden_states = hidden_states.sum(dim=1)
+        hidden_states = hidden_states.sum(dim = 1)
 
         hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
         return hidden_states, router_logits
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
index 85ebf147f..46d9c3c51 100644
--- a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
+++ b/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
@@ -52,9 +52,13 @@ def calculate_topk(
 
     def _activation(gating_output: torch.Tensor):
         if use_sigmoid:
-            scores = torch.sigmoid(gating_output.to(torch.float32)).to(gating_output.dtype)
+            scores = torch.sigmoid(gating_output.to(torch.float32)).to(
+                gating_output.dtype
+            )
         else:
-            scores = F.softmax(gating_output.to(torch.float32), dim=1).to(gating_output.dtype)
+            scores = F.softmax(gating_output.to(torch.float32), dim = 1).to(
+                gating_output.dtype
+            )
 
         return scores
 
@@ -63,19 +67,23 @@ def _activation(gating_output: torch.Tensor):
     else:
         scores = gating_output
 
-    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=1)
+    topk_weights, topk_ids = torch.topk(scores, k = top_k, dim = 1)
 
     if post_act:
         topk_weights = _activation(topk_weights)
 
     if renormalize:
-        topk_weights /= torch.sum(topk_weights, dim=-1, keepdim=True).to(gating_output.dtype)
+        topk_weights /= torch.sum(topk_weights, dim = -1, keepdim = True).to(
+            gating_output.dtype
+        )
 
     return topk_weights, topk_ids
 
 
 @torch.no_grad()
-def get_routing_indices(selected_experts, num_experts, return_scatter_indices: bool = False):
+def get_routing_indices(
+    selected_experts, num_experts, return_scatter_indices: bool = False
+):
     """
     Returns:
         token_counts_by_expert: [num_experts]
@@ -86,12 +94,12 @@ def get_routing_indices(selected_experts, num_experts, return_scatter_indices: b
     # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
     token_counts_by_expert = torch.histc(
         selected_experts.view(-1),
-        bins=num_experts,
-        min=0,
-        max=num_experts,
+        bins = num_experts,
+        min = 0,
+        max = num_experts,
     )
     # token_indices_experts_sorted shape (bs*slen*top_k,)
-    gather_indices = torch.argsort(selected_experts.view(-1), stable=True)
+    gather_indices = torch.argsort(selected_experts.view(-1), stable = True)
     if return_scatter_indices:
         scatter_indices = gather_indices.argsort()
         return token_counts_by_expert, gather_indices, scatter_indices
@@ -99,7 +107,7 @@ def get_routing_indices(selected_experts, num_experts, return_scatter_indices: b
         return token_counts_by_expert, gather_indices
 
 
-def torch_grouped_gemm(X, W, m_sizes, transpose=True):
+def torch_grouped_gemm(X, W, m_sizes, transpose = True):
     """
     X: [M, K] if forward, else [M, N]
     W: [E, N, K]
@@ -119,7 +127,7 @@ def torch_grouped_gemm(X, W, m_sizes, transpose=True):
 
     N = W.shape[1]
 
-    result = torch.zeros((M, N), dtype=X.dtype, device=X.device)
+    result = torch.zeros((M, N), dtype = X.dtype, device = X.device)
 
     m_start = 0
     for g in range(E):
diff --git a/unsloth/kernels/moe/tests/common.py b/unsloth/kernels/moe/tests/common.py
index 67461b39f..bfe6f2094 100644
--- a/unsloth/kernels/moe/tests/common.py
+++ b/unsloth/kernels/moe/tests/common.py
@@ -18,7 +18,7 @@
 )
 
 
-def print_delimiter(char="-", length=80):
+def print_delimiter(char = "-", length = 80):
     print(char * length)
 
 
@@ -29,28 +29,28 @@ def delimiter_context():
     print_delimiter()
 
 
-def make_inputs(M, N, K, E, topk, dtype, requires_grad=False):
+def make_inputs(M, N, K, E, topk, dtype, requires_grad = False):
     X1 = (
-        torch.randn((M, K), device="cuda", dtype=dtype, requires_grad=requires_grad)
+        torch.randn((M, K), device = "cuda", dtype = dtype, requires_grad = requires_grad)
         / 10
     )
     X2 = (
         torch.randn(
-            (M * topk, N), device="cuda", dtype=dtype, requires_grad=requires_grad
+            (M * topk, N), device = "cuda", dtype = dtype, requires_grad = requires_grad
         )
         / 10
     )
     W1 = (
         torch.randn(
-            (E, 2 * N, K), device="cuda", dtype=dtype, requires_grad=requires_grad
+            (E, 2 * N, K), device = "cuda", dtype = dtype, requires_grad = requires_grad
         )
         / 10
     )
     W2 = (
-        torch.randn((E, K, N), device="cuda", dtype=dtype, requires_grad=requires_grad)
+        torch.randn((E, K, N), device = "cuda", dtype = dtype, requires_grad = requires_grad)
         / 10
     )
-    score = torch.randn((M, E), device="cuda", dtype=dtype, requires_grad=requires_grad)
+    score = torch.randn((M, E), device = "cuda", dtype = dtype, requires_grad = requires_grad)
     if requires_grad:
         X1.retain_grad()
         X2.retain_grad()
@@ -60,7 +60,7 @@ def make_inputs(M, N, K, E, topk, dtype, requires_grad=False):
     return X1, X2, W1, W2, score
 
 
-@dataclass(kw_only=True)
+@dataclass(kw_only = True)
 class DataConfig:
     seq_len: int
     dtype: torch.dtype
@@ -68,7 +68,7 @@ class DataConfig:
     bs: int = 1
 
 
-@dataclass(kw_only=True)
+@dataclass(kw_only = True)
 class ModelConfig:
     hidden_size: int
     intermediate_size: int
@@ -77,13 +77,13 @@ class ModelConfig:
     use_sigmoid: bool
     renormalize: bool
     pre_mul: bool = False
-    post_mul: bool = field(init=False)
+    post_mul: bool = field(init = False)
 
     def __post_init__(self):
         self.post_mul = not self.pre_mul
 
 
-@dataclass(kw_only=True)
+@dataclass(kw_only = True)
 class GroupedGEMMTestConfig:
     name: str = "test"
     data_config: DataConfig
@@ -105,7 +105,7 @@ def assert_equal(ref, tri):
         assert ref == tri, f"ref not equal to tri {ref} != {tri}"
 
 
-def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
+def assert_close(ref, tri, maxtol = None, rmstol = None, description = "--", verbose = True):
     if tri.dtype.itemsize == 1:
         ref_as_type = ref.to(tri.dtype)
         if ref.dtype == tri.dtype:
@@ -124,16 +124,16 @@ def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=T
     # cast to float32:
     ref = ref.to(torch.float32).detach()
     tri = tri.to(torch.float32).detach()
-    assert ref.shape == tri.shape, (
-        f"Tensors must have same size {ref.shape=} {tri.shape=}"
-    )
+    assert (
+        ref.shape == tri.shape
+    ), f"Tensors must have same size {ref.shape = } {tri.shape = }"
 
     # deal with infinite elements:
     inf_mask_ref = torch.isinf(ref)
     inf_mask_tri = torch.isinf(tri)
-    assert torch.equal(inf_mask_ref, inf_mask_tri), (
-        "Tensor must have same infinite elements"
-    )
+    assert torch.equal(
+        inf_mask_ref, inf_mask_tri
+    ), "Tensor must have same infinite elements"
     refn = torch.where(inf_mask_ref, 0, ref)
     trin = torch.where(inf_mask_tri, 0, tri)
 
@@ -182,11 +182,11 @@ def assert_indx_equal(ref, tri):
 
 
 def get_kernel_test_configs(
-    BLOCK_SIZE_M=32,
-    BLOCK_SIZE_N=32,
-    BLOCK_SIZE_K=32,
-    num_warps=4,
-    num_stages=2,
+    BLOCK_SIZE_M = 32,
+    BLOCK_SIZE_N = 32,
+    BLOCK_SIZE_K = 32,
+    num_warps = 4,
+    num_stages = 2,
 ) -> list[KernelConfig]:
     configs_fwd = []
     configs_bwd_dX = []
@@ -199,44 +199,44 @@ def get_kernel_test_configs(
                     for use_tma_store in [True, False]:
                         configs_fwd.append(
                             KernelConfigForward(
-                                BLOCK_SIZE_M=BLOCK_SIZE_M,
-                                BLOCK_SIZE_N=BLOCK_SIZE_N,
-                                BLOCK_SIZE_K=BLOCK_SIZE_K,
-                                num_warps=num_warps,
-                                num_stages=num_stages,
-                                use_tma_load_w=use_tma_load_w,
-                                use_tma_load_x=use_tma_load_x,
-                                use_tma_store=use_tma_store,
-                                permute_x=permute_x,
-                                permute_y=permute_y,
+                                BLOCK_SIZE_M = BLOCK_SIZE_M,
+                                BLOCK_SIZE_N = BLOCK_SIZE_N,
+                                BLOCK_SIZE_K = BLOCK_SIZE_K,
+                                num_warps = num_warps,
+                                num_stages = num_stages,
+                                use_tma_load_w = use_tma_load_w,
+                                use_tma_load_x = use_tma_load_x,
+                                use_tma_store = use_tma_store,
+                                permute_x = permute_x,
+                                permute_y = permute_y,
                             )
                         )
                         configs_bwd_dX.append(
                             KernelConfigBackward_dX(
-                                BLOCK_SIZE_M=BLOCK_SIZE_M,
-                                BLOCK_SIZE_N=BLOCK_SIZE_N,
-                                BLOCK_SIZE_K=BLOCK_SIZE_K,
-                                num_warps=num_warps,
-                                num_stages=num_stages,
-                                use_tma_load_dy=use_tma_load_x,
-                                use_tma_load_w=use_tma_load_w,
-                                permute_x=permute_x,
-                                permute_y=permute_y,
-                                use_tma_store=use_tma_store,
+                                BLOCK_SIZE_M = BLOCK_SIZE_M,
+                                BLOCK_SIZE_N = BLOCK_SIZE_N,
+                                BLOCK_SIZE_K = BLOCK_SIZE_K,
+                                num_warps = num_warps,
+                                num_stages = num_stages,
+                                use_tma_load_dy = use_tma_load_x,
+                                use_tma_load_w = use_tma_load_w,
+                                permute_x = permute_x,
+                                permute_y = permute_y,
+                                use_tma_store = use_tma_store,
                             )
                         )
                         configs_bwd_dW.append(
                             KernelConfigBackward_dW(
-                                BLOCK_SIZE_M=BLOCK_SIZE_M,
-                                BLOCK_SIZE_N=BLOCK_SIZE_N,
-                                BLOCK_SIZE_K=BLOCK_SIZE_K,
-                                num_warps=num_warps,
-                                num_stages=num_stages,
-                                use_tma_load_dy=use_tma_load_w,
-                                use_tma_load_x=use_tma_load_x,
-                                permute_x=permute_x,
-                                permute_y=permute_y,
-                                use_tma_store=use_tma_store,
+                                BLOCK_SIZE_M = BLOCK_SIZE_M,
+                                BLOCK_SIZE_N = BLOCK_SIZE_N,
+                                BLOCK_SIZE_K = BLOCK_SIZE_K,
+                                num_warps = num_warps,
+                                num_stages = num_stages,
+                                use_tma_load_dy = use_tma_load_w,
+                                use_tma_load_x = use_tma_load_x,
+                                permute_x = permute_x,
+                                permute_y = permute_y,
+                                use_tma_store = use_tma_store,
                             )
                         )
     configs_fwd = prune_kernel_configs_fwd(configs_fwd)
@@ -289,39 +289,39 @@ def remove_feature_flags(
 
 SMALL_MODEL_CONFIGS = [
     ModelConfig(
-        topk=topk,
-        num_experts=num_experts,
-        hidden_size=model_size[0],
-        intermediate_size=model_size[1],
-        use_sigmoid=False,
-        renormalize=False,
+        topk = topk,
+        num_experts = num_experts,
+        hidden_size = model_size[0],
+        intermediate_size = model_size[1],
+        use_sigmoid = False,
+        renormalize = False,
     )
     for topk, num_experts, model_size in itertools.product(
         TOPK, NUM_EXPERTS, TEST_MODEL_SIZES
     )
 ]
 LLAMA_MODEL_CONFIG = ModelConfig(
-    topk=1,
-    num_experts=16,
-    hidden_size=5120,
-    intermediate_size=8192,
-    use_sigmoid=True,
-    renormalize=False,
+    topk = 1,
+    num_experts = 16,
+    hidden_size = 5120,
+    intermediate_size = 8192,
+    use_sigmoid = True,
+    renormalize = False,
 )
 QWEN_MODEL_CONFIG = ModelConfig(
-    topk=8,
-    num_experts=128,
-    hidden_size=2048,
-    intermediate_size=768,
-    use_sigmoid=False,
-    renormalize=False,
+    topk = 8,
+    num_experts = 128,
+    hidden_size = 2048,
+    intermediate_size = 768,
+    use_sigmoid = False,
+    renormalize = False,
 )
 
 SEQLENS = [128, 1024]
 DTYPE = [torch.bfloat16]
 
 DATA_CONFIGS = [
-    DataConfig(seq_len=seq_len, dtype=dtype)
+    DataConfig(seq_len = seq_len, dtype = dtype)
     for seq_len, dtype in itertools.product(SEQLENS, DTYPE)
 ]
 KERNEL_CONFIGS_FWD, KERNEL_CONFIGS_BWD_dX, KERNEL_CONFIGS_BWD_dW = (
@@ -331,6 +331,6 @@ def remove_feature_flags(
 if __name__ == "__main__":
     print(
         KERNEL_CONFIGS_BWD_dX[0].to_string(
-            include_tuning_params=False, include_tma=False
+            include_tuning_params = False, include_tma = False
         )
     )
diff --git a/unsloth/kernels/moe/tests/moe_utils.py b/unsloth/kernels/moe/tests/moe_utils.py
index b32e4476c..26ac9fdb5 100644
--- a/unsloth/kernels/moe/tests/moe_utils.py
+++ b/unsloth/kernels/moe/tests/moe_utils.py
@@ -33,13 +33,13 @@ def rebind_experts_to_shared_buffer(
     dtype = moe_block.experts[0].down_proj.weight.dtype
 
     buffer_up = torch.empty(
-        num_experts, interm_size, hidden_size, device=device, dtype=dtype
+        num_experts, interm_size, hidden_size, device = device, dtype = dtype
     )
     buffer_gate = torch.empty(
-        num_experts, interm_size, hidden_size, device=device, dtype=dtype
+        num_experts, interm_size, hidden_size, device = device, dtype = dtype
     )
     buffer_down = torch.empty(
-        num_experts, hidden_size, interm_size, device=device, dtype=dtype
+        num_experts, hidden_size, interm_size, device = device, dtype = dtype
     )
 
     # Step 2: Copy existing expert weights into buffers
@@ -114,7 +114,7 @@ def check_down_proj_grad(
         test_grad = grouped_gemm_block.down_proj.grad[i]
         assert test_grad is not None
         diff = (ref_grad - test_grad).abs().max()
-        if not torch.allclose(ref_grad, test_grad, atol=atol, rtol=rtol):
+        if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
             print(f"expert {i} down_proj_grad_diff: {diff.detach().cpu().item():.6f}")
 
 
@@ -142,22 +142,22 @@ def check_gate_up_proj_grad(
         assert test_up_proj_grad is not None
 
         # Sanity check shapes
-        assert ref_gate_proj_grad.shape == test_gate_proj_grad.shape, (
-            f"{ref_gate_proj_grad.shape} != {test_gate_proj_grad.shape}"
-        )
-        assert ref_up_proj_grad.shape == test_up_proj_grad.shape, (
-            f"{ref_up_proj_grad.shape} != {test_up_proj_grad.shape}"
-        )
+        assert (
+            ref_gate_proj_grad.shape == test_gate_proj_grad.shape
+        ), f"{ref_gate_proj_grad.shape} != {test_gate_proj_grad.shape}"
+        assert (
+            ref_up_proj_grad.shape == test_up_proj_grad.shape
+        ), f"{ref_up_proj_grad.shape} != {test_up_proj_grad.shape}"
 
         # Check gradients
         diff = (ref_gate_proj_grad - test_gate_proj_grad).abs().max()
         if not torch.allclose(
-            ref_gate_proj_grad, test_gate_proj_grad, atol=atol, rtol=rtol
+            ref_gate_proj_grad, test_gate_proj_grad, atol = atol, rtol = rtol
         ):
             print(f"expert {i} gate_proj_grad_diff: {diff.detach().cpu().item():.6f}")
         diff = (ref_up_proj_grad - test_up_proj_grad).abs().max()
         if not torch.allclose(
-            ref_up_proj_grad, test_up_proj_grad, atol=atol, rtol=rtol
+            ref_up_proj_grad, test_up_proj_grad, atol = atol, rtol = rtol
         ):
             print(f"expert {i} up_proj_grad_diff: {diff.detach().cpu().item():.6f}")
 
@@ -173,7 +173,7 @@ def check_gate_grad(
     test_grad = grouped_gemm_block.gate.grad
     assert test_grad is not None
     diff = (ref_grad - test_grad).abs().max()
-    if not torch.allclose(ref_grad, test_grad, atol=atol, rtol=rtol):
+    if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
         print(f"gate_grad_diff: {diff.detach().cpu().item():.6f}")
 
 
@@ -199,9 +199,9 @@ def check_tensor_allclose(
     diff = (X_ref - X_test).abs().max()
     if verbose:
         print(f"{name} diff: {diff.detach().cpu().item():.6f}")
-    assert torch.allclose(X_ref, X_test, atol=atol, rtol=rtol), (
-        f"{name} diff: {diff.detach().cpu().item():.6f}"
-    )
+    assert torch.allclose(
+        X_ref, X_test, atol = atol, rtol = rtol
+    ), f"{name} diff: {diff.detach().cpu().item():.6f}"
 
 
 def check_expert_grads(
@@ -217,26 +217,26 @@ def check_expert_grads(
     for field in fields_to_check:
         ref_grads = getattr(ref_result, field)
         test_grads = getattr(test_result, field)
-        assert ref_grads.shape == test_grads.shape, (
-            f"{field}: {ref_grads.shape} != {test_grads.shape}"
-        )
+        assert (
+            ref_grads.shape == test_grads.shape
+        ), f"{field}: {ref_grads.shape} != {test_grads.shape}"
 
         # Test each expert
         for i in range(ref_grads.shape[0]):
             ref_grad = ref_grads[i]
             test_grad = test_grads[i]
             diff = (ref_grad - test_grad).abs().max()
-            assert torch.allclose(ref_grad, test_grad, atol=atol, rtol=rtol), (
-                f"{field}[{i}] diff: {diff.detach().cpu().item():.6f}"
-            )
+            assert torch.allclose(
+                ref_grad, test_grad, atol = atol, rtol = rtol
+            ), f"{field}[{i}] diff: {diff.detach().cpu().item():.6f}"
 
         # Test all experts
         diff = (ref_grads - test_grads).abs().max()
         if verbose:
             print(f"{field} diff: {diff.detach().cpu().item():.6f}")
-        assert torch.allclose(ref_grads, test_grads, atol=atol, rtol=rtol), (
-            f"{field} diff: {diff.detach().cpu().item():.6f}"
-        )
+        assert torch.allclose(
+            ref_grads, test_grads, atol = atol, rtol = rtol
+        ), f"{field} diff: {diff.detach().cpu().item():.6f}"
 
 
 def check_grads(
@@ -268,9 +268,9 @@ def check_fwd(
     diff = (ref_output - test_output).abs().max()
     if verbose:
         print(f"output diff: {diff.detach().cpu().item():.6f}")
-    assert torch.allclose(ref_output, test_output, atol=atol, rtol=rtol), (
-        f"output diff: {diff.detach().cpu().item():.6f}"
-    )
+    assert torch.allclose(
+        ref_output, test_output, atol = atol, rtol = rtol
+    ), f"output diff: {diff.detach().cpu().item():.6f}"
 
     # Check router logits
     ref_router_logits = ref_result.router_logits
@@ -279,7 +279,7 @@ def check_fwd(
     if verbose:
         print(f"router_logits diff: {diff.detach().cpu().item():.6f}")
     assert torch.allclose(
-        ref_router_logits, test_router_logits, atol=atol, rtol=rtol
+        ref_router_logits, test_router_logits, atol = atol, rtol = rtol
     ), f"router_logits diff: {diff.detach().cpu().item():.6f}"
 
 
@@ -304,9 +304,9 @@ def check_grouped_gemm_results(
         if verbose:
             print(f"{field.name} diff: {diff.detach().cpu().item():.6f}")
 
-        assert torch.allclose(ref_value, test_value, atol=atol, rtol=rtol), (
-            f"{field.name} diff: {diff.detach().cpu().item():.6f}"
-        )
+        assert torch.allclose(
+            ref_value, test_value, atol = atol, rtol = rtol
+        ), f"{field.name} diff: {diff.detach().cpu().item():.6f}"
 
 
 def run_forward(model: nn.Module, X: torch.Tensor, is_grouped_gemm: bool = False):
@@ -314,13 +314,13 @@ def run_forward(model: nn.Module, X: torch.Tensor, is_grouped_gemm: bool = False
     output, router_logits = model(X)
     if is_grouped_gemm:
         result = ForwardResult(
-            output=output.hidden_states,
-            router_logits=router_logits,
-            X=X,
-            grouped_gemm_result=output,
+            output = output.hidden_states,
+            router_logits = router_logits,
+            X = X,
+            grouped_gemm_result = output,
         )
     else:
-        result = ForwardResult(output=output, router_logits=router_logits, X=X)
+        result = ForwardResult(output = output, router_logits = router_logits, X = X)
     return result
 
 
@@ -344,16 +344,16 @@ def run_backward(
         )
     elif isinstance(model, Qwen3MoeGroupedGEMMBlock):
         gate_grad = model.gate.grad
-        gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim=1)
+        gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim = 1)
         down_proj_grad = model.down_proj.grad
     else:
         raise ValueError(f"Unsupported model type: {type(model)}")
     return BackwardResult(
-        X_grad=X.grad,
-        gate_grad=gate_grad,
-        gate_proj_grad=gate_proj_grad,
-        up_proj_grad=up_proj_grad,
-        down_proj_grad=down_proj_grad,
+        X_grad = X.grad,
+        gate_grad = gate_grad,
+        gate_proj_grad = gate_proj_grad,
+        up_proj_grad = up_proj_grad,
+        down_proj_grad = down_proj_grad,
     )
 
 
@@ -414,12 +414,12 @@ def from_hf(
             gate,
             gate_up_proj,
             down_proj,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            autotune=autotune,
-            kernel_config_fwd=kernel_config_fwd,
-            kernel_config_bwd_dW=kernel_config_bwd_dW,
-            kernel_config_bwd_dX=kernel_config_bwd_dX,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            autotune = autotune,
+            kernel_config_fwd = kernel_config_fwd,
+            kernel_config_bwd_dW = kernel_config_bwd_dW,
+            kernel_config_bwd_dX = kernel_config_bwd_dX,
         )
 
     def forward(self, hidden_states: torch.Tensor, debug: bool = False) -> torch.Tensor:
@@ -446,35 +446,35 @@ def forward(self, hidden_states: torch.Tensor, debug: bool = False) -> torch.Ten
 
         # Start expert computation
         first_gemm = grouped_gemm(
-            X=hidden_states,
-            W=self.gate_up_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=self.permute_x,
-            permute_y=False,  # output of first grouped gemm should never be permuted
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=True,
+            X = hidden_states,
+            W = self.gate_up_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = self.permute_x,
+            permute_y = False,  # output of first grouped gemm should never be permuted
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = True,
         )
         assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
         intermediate = self.act_and_mul(first_gemm)
         assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
         second_gemm = grouped_gemm(
-            X=intermediate,
-            W=self.down_proj,
-            m_sizes=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk=self.top_k,
-            permute_x=False,
-            permute_y=self.permute_y,
-            autotune=self.autotune,
-            kernel_config_fwd=self.kernel_config_fwd,
-            kernel_config_bwd_dW=self.kernel_config_bwd_dW,
-            kernel_config_bwd_dX=self.kernel_config_bwd_dX,
-            is_first_gemm=False,
+            X = intermediate,
+            W = self.down_proj,
+            m_sizes = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk = self.top_k,
+            permute_x = False,
+            permute_y = self.permute_y,
+            autotune = self.autotune,
+            kernel_config_fwd = self.kernel_config_fwd,
+            kernel_config_bwd_dW = self.kernel_config_bwd_dW,
+            kernel_config_bwd_dX = self.kernel_config_bwd_dX,
+            is_first_gemm = False,
         )
         assert second_gemm.shape == (total_tokens, hidden_dim)
 
@@ -491,17 +491,17 @@ def forward(self, hidden_states: torch.Tensor, debug: bool = False) -> torch.Ten
             hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
             * routing_weights[..., None]
         )
-        hidden_states = hidden_states.sum(dim=1)
+        hidden_states = hidden_states.sum(dim = 1)
         assert hidden_states.shape == (num_tokens, hidden_dim)
 
         hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
         return GroupedGEMMResult(
-            token_counts_by_expert=token_counts_by_expert,
-            gather_indices=gather_indices,
-            topk_weights=routing_weights,
-            first_gemm=first_gemm,
-            intermediate=intermediate,
-            second_gemm=second_gemm,
-            hidden_states_unpermute=hidden_states_unpermute,
-            hidden_states=hidden_states,
+            token_counts_by_expert = token_counts_by_expert,
+            gather_indices = gather_indices,
+            topk_weights = routing_weights,
+            first_gemm = first_gemm,
+            intermediate = intermediate,
+            second_gemm = second_gemm,
+            hidden_states_unpermute = hidden_states_unpermute,
+            hidden_states = hidden_states,
         ), router_logits
diff --git a/unsloth/kernels/moe/tests/test_grouped_gemm.py b/unsloth/kernels/moe/tests/test_grouped_gemm.py
index 38b7e00ca..bd98b6a27 100644
--- a/unsloth/kernels/moe/tests/test_grouped_gemm.py
+++ b/unsloth/kernels/moe/tests/test_grouped_gemm.py
@@ -50,29 +50,29 @@
 # permute_y => permute the output of the grouped GEMM, only done for the second grouped GEMM
 # fuse_mul_post => fuse the multiplication of topk weights in the epilogue of the second grouped GEMM; only used for inference, not currently tested
 def check_valid_config(
-    permute_x, permute_y, use_W1, fuse_mul_post=False, is_backward=False, verbose=False
+    permute_x, permute_y, use_W1, fuse_mul_post = False, is_backward = False, verbose = False
 ):
     use_W2 = not use_W1
 
     if permute_x and permute_y:
         if verbose:
-            print(f"Skipping test: {permute_x=} {permute_y=}")
+            print(f"Skipping test: {permute_x = } {permute_y = }")
         return False
     if use_W2 and permute_x:
         if verbose:
-            print(f"Skipping test: {permute_x=} {use_W2=}")
+            print(f"Skipping test: {permute_x = } {use_W2 = }")
         return False
     if use_W1 and permute_y:
         if verbose:
-            print(f"Skipping test: {permute_y=} {use_W1=}")
+            print(f"Skipping test: {permute_y = } {use_W1 = }")
         return False
     if fuse_mul_post and use_W1:
         if verbose:
-            print(f"Skipping test: {fuse_mul_post=} {use_W1=}")
+            print(f"Skipping test: {fuse_mul_post = } {use_W1 = }")
         return False
     if is_backward and fuse_mul_post:
         if verbose:
-            print(f"Skipping test: {fuse_mul_post=} {is_backward=}")
+            print(f"Skipping test: {fuse_mul_post = } {is_backward = }")
         return False
 
     return True
@@ -122,22 +122,22 @@ def _test_grouped_gemm_forward(
     use_autograd: bool = False,
 ):
     if not check_valid_config(
-        permute_x, permute_y, use_W1=use_W1, fuse_mul_post=fuse_mul_post
+        permute_x, permute_y, use_W1 = use_W1, fuse_mul_post = fuse_mul_post
     ):
         pytest.skip(
-            f"Skipping test due to invalid config: {permute_x=} {permute_y=} {use_W1=} {fuse_mul_post=}"
+            f"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = } {fuse_mul_post = }"
         )
 
     if use_tma_store and not allow_tma_store:
         pytest.skip("TMA store needs to be debugged due to non-deterministic behavior")
 
     X1, X2, W1, W2, gating_output = make_inputs(
-        M=data_config.bs * data_config.seq_len,
-        N=model_config.intermediate_size,
-        K=model_config.hidden_size,
-        E=model_config.num_experts,
-        topk=model_config.topk,
-        dtype=data_config.dtype,
+        M = data_config.bs * data_config.seq_len,
+        N = model_config.intermediate_size,
+        K = model_config.hidden_size,
+        E = model_config.num_experts,
+        topk = model_config.topk,
+        dtype = data_config.dtype,
     )
     topk = model_config.topk
     use_sigmoid = model_config.use_sigmoid
@@ -150,24 +150,26 @@ def _test_grouped_gemm_forward(
     W = W1 if use_W1 else W2
 
     if use_W1:
-        assert X.shape == (num_tokens, K), (
-            f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
-        )
+        assert X.shape == (
+            num_tokens,
+            K,
+        ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
     else:
-        assert X.shape == (num_tokens * topk, N), (
-            f"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}"
-        )
+        assert X.shape == (
+            num_tokens * topk,
+            N,
+        ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}"
 
     total_tokens = num_tokens * topk
     output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
 
     topk_weights, topk_ids = calculate_topk(
-        gating_output, topk, use_sigmoid=use_sigmoid, renormalize=renormalize
+        gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize
     )
     topk_weights = topk_weights.view(-1)  # num_tokens * topk
     topk_ids = topk_ids.view(-1)  # num_tokens * topk
 
-    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts=E)
+    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)
     assert len(gather_indices) == total_tokens
     assert len(expert_token_counts) == E
 
@@ -177,11 +179,11 @@ def _test_grouped_gemm_forward(
 
     Xref = Xperm
 
-    assert Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N), (
-        f"Xperm.shape: {Xperm.shape}, total_tokens: {total_tokens}, K: {K}"
-    )
+    assert (
+        Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)
+    ), f"Xperm.shape: {Xperm.shape}, total_tokens: {total_tokens}, K: {K}"
 
-    ref_output = torch_grouped_gemm(X=Xref, W=W, m_sizes=expert_token_counts)
+    ref_output = torch_grouped_gemm(X = Xref, W = W, m_sizes = expert_token_counts)
 
     if permute_x:
         X_test = X
@@ -202,55 +204,55 @@ def _test_grouped_gemm_forward(
         from grouped_gemm.interface import grouped_gemm
 
         kernel_config_fwd = KernelConfigForward(
-            BLOCK_SIZE_M=BLOCK_SIZE_M,
-            BLOCK_SIZE_N=BLOCK_SIZE_N,
-            BLOCK_SIZE_K=BLOCK_SIZE_K,
-            num_warps=num_warps,
-            num_stages=num_stages,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            fuse_mul_post=fuse_mul_post,
-            use_tma_load_w=use_tma_load_w,
-            use_tma_load_x=use_tma_load_x,
-            use_tma_store=use_tma_store,
+            BLOCK_SIZE_M = BLOCK_SIZE_M,
+            BLOCK_SIZE_N = BLOCK_SIZE_N,
+            BLOCK_SIZE_K = BLOCK_SIZE_K,
+            num_warps = num_warps,
+            num_stages = num_stages,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            fuse_mul_post = fuse_mul_post,
+            use_tma_load_w = use_tma_load_w,
+            use_tma_load_x = use_tma_load_x,
+            use_tma_store = use_tma_store,
         )
 
         test_output = grouped_gemm(
-            X=X_test,
-            W=W,
-            topk=topk,
-            m_sizes=expert_token_counts,
-            gather_indices=gather_indices,
-            topk_weights=topk_weights if fuse_mul_post else None,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            fuse_mul_post=fuse_mul_post,
-            kernel_config_fwd=kernel_config_fwd,
-            autotune=autotune,
-            is_first_gemm=use_W1,
+            X = X_test,
+            W = W,
+            topk = topk,
+            m_sizes = expert_token_counts,
+            gather_indices = gather_indices,
+            topk_weights = topk_weights if fuse_mul_post else None,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            fuse_mul_post = fuse_mul_post,
+            kernel_config_fwd = kernel_config_fwd,
+            autotune = autotune,
+            is_first_gemm = use_W1,
         )
     # Use manual interface
     else:
         test_output = grouped_gemm_forward(
-            X=X_test,
-            W=W,
-            topk=topk,
-            m_sizes=expert_token_counts,
-            gather_indices=gather_indices,
-            topk_weights=topk_weights if fuse_mul_post else None,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            fuse_mul_post=fuse_mul_post,
-            use_tma_load_w=use_tma_load_w,
-            use_tma_load_x=use_tma_load_x,
-            use_tma_store=use_tma_store,
-            autotune=autotune,
-            BLOCK_SIZE_M=BLOCK_SIZE_M,
-            BLOCK_SIZE_N=BLOCK_SIZE_N,
-            BLOCK_SIZE_K=BLOCK_SIZE_K,
-            num_warps=num_warps,
-            num_stages=num_stages,
-            flatten=flatten,
+            X = X_test,
+            W = W,
+            topk = topk,
+            m_sizes = expert_token_counts,
+            gather_indices = gather_indices,
+            topk_weights = topk_weights if fuse_mul_post else None,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            fuse_mul_post = fuse_mul_post,
+            use_tma_load_w = use_tma_load_w,
+            use_tma_load_x = use_tma_load_x,
+            use_tma_store = use_tma_store,
+            autotune = autotune,
+            BLOCK_SIZE_M = BLOCK_SIZE_M,
+            BLOCK_SIZE_N = BLOCK_SIZE_N,
+            BLOCK_SIZE_K = BLOCK_SIZE_K,
+            num_warps = num_warps,
+            num_stages = num_stages,
+            flatten = flatten,
         )
     assert ref_output.shape == output_shape
     assert test_output.shape == output_shape
@@ -265,26 +267,26 @@ def _test_grouped_gemm_forward(
             test_output = unpermute(test_output, gather_indices)
         ref_output = ref_output * topk_weights[:, None]
 
-    assert torch.allclose(ref_output, test_output, atol=atol, rtol=rtol), (
-        f"Grouped gemm forward failed: {(ref_output - test_output).abs().max().item():.6f}"
-    )
+    assert torch.allclose(
+        ref_output, test_output, atol = atol, rtol = rtol
+    ), f"Grouped gemm forward failed: {(ref_output - test_output).abs().max().item():.6f}"
 
 
 # NOTE: Fuse multiplication of topk weights is only supported for inference and not training, although this may change in the future; not currently tested.
 @pytest.mark.parametrize(
     "kernel_config",
     KERNEL_CONFIGS_FWD,
-    ids=lambda x: x.to_string(include_tuning_params=True, include_tma=True),
+    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
 )
 @pytest.mark.parametrize(
     "model_config",
     SMALL_MODEL_CONFIGS + [QWEN_MODEL_CONFIG, LLAMA_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_forward_manual(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -292,9 +294,9 @@ def test_grouped_gemm_forward_manual(
     use_W1: bool,
 ):
     _test_grouped_gemm_forward(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
         **asdict(kernel_config),
     )
 
@@ -302,17 +304,17 @@ def test_grouped_gemm_forward_manual(
 @pytest.mark.parametrize(
     "kernel_config",
     KERNEL_CONFIGS_FWD,
-    ids=lambda x: x.to_string(include_tuning_params=True, include_tma=True),
+    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
 )
 @pytest.mark.parametrize(
     "model_config",
     SMALL_MODEL_CONFIGS + [QWEN_MODEL_CONFIG, LLAMA_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_forward_manual_autograd(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -320,32 +322,32 @@ def test_grouped_gemm_forward_manual_autograd(
     use_W1: bool,
 ):
     _test_grouped_gemm_forward(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
-        use_autograd=True,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
+        use_autograd = True,
         **asdict(kernel_config),
     )
 
 
 @pytest.mark.parametrize(
-    "num_autotune_configs", [10], ids=lambda x: f"num_autotune_configs={x}"
+    "num_autotune_configs", [10], ids = lambda x: f"num_autotune_configs={x}"
 )
 @pytest.mark.parametrize(
-    "permute_x", [True, False], ids=lambda x: "permute_x" if x else ""
+    "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
 )
 @pytest.mark.parametrize(
-    "permute_y", [True, False], ids=lambda x: "permute_y" if x else ""
+    "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
 )
 @pytest.mark.parametrize(
     "model_config",
     [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_forward_autotune(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -355,35 +357,35 @@ def test_grouped_gemm_forward_autotune(
     num_autotune_configs: int,
 ):
     _test_grouped_gemm_forward(
-        data_config=data_config,
-        model_config=model_config,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        use_W1=use_W1,
-        num_autotune_configs=num_autotune_configs,
-        autotune=True,
-        use_autograd=False,
+        data_config = data_config,
+        model_config = model_config,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        use_W1 = use_W1,
+        num_autotune_configs = num_autotune_configs,
+        autotune = True,
+        use_autograd = False,
     )
 
 
 @pytest.mark.parametrize(
-    "num_autotune_configs", [10], ids=lambda x: f"num_autotune_configs={x}"
+    "num_autotune_configs", [10], ids = lambda x: f"num_autotune_configs={x}"
 )
 @pytest.mark.parametrize(
-    "permute_x", [True, False], ids=lambda x: "permute_x" if x else ""
+    "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
 )
 @pytest.mark.parametrize(
-    "permute_y", [True, False], ids=lambda x: "permute_y" if x else ""
+    "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
 )
 @pytest.mark.parametrize(
     "model_config",
     [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_forward_autotune_autograd(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -393,14 +395,14 @@ def test_grouped_gemm_forward_autotune_autograd(
     num_autotune_configs: int,
 ):
     _test_grouped_gemm_forward(
-        data_config=data_config,
-        model_config=model_config,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        use_W1=use_W1,
-        num_autotune_configs=num_autotune_configs,
-        autotune=True,
-        use_autograd=True,
+        data_config = data_config,
+        model_config = model_config,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        use_W1 = use_W1,
+        num_autotune_configs = num_autotune_configs,
+        autotune = True,
+        use_autograd = True,
     )
 
 
@@ -459,9 +461,9 @@ def _test_grouped_gemm_backward_dX(
     use_autograd: bool = False,
     fuse_mul_post: bool = False,
 ):
-    if not check_valid_config(permute_x, permute_y, use_W1=use_W1, is_backward=True):
+    if not check_valid_config(permute_x, permute_y, use_W1 = use_W1, is_backward = True):
         pytest.skip(
-            f"Skipping test due to invalid config: {permute_x=} {permute_y=} {use_W1=}"
+            f"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = }"
         )
 
     if use_tma_store and not allow_tma_store:
@@ -482,13 +484,13 @@ def _test_grouped_gemm_backward_dX(
 
     use_W2 = not use_W1
     X1, X2, W1, W2, gating_output = make_inputs(
-        M=data_config.bs * data_config.seq_len,
-        N=model_config.intermediate_size,
-        K=model_config.hidden_size,
-        E=model_config.num_experts,
-        topk=model_config.topk,
-        dtype=data_config.dtype,
-        requires_grad=True,
+        M = data_config.bs * data_config.seq_len,
+        N = model_config.intermediate_size,
+        K = model_config.hidden_size,
+        E = model_config.num_experts,
+        topk = model_config.topk,
+        dtype = data_config.dtype,
+        requires_grad = True,
     )
     topk = model_config.topk
     num_experts = model_config.num_experts
@@ -504,23 +506,25 @@ def _test_grouped_gemm_backward_dX(
     W = W1 if use_W1 else W2
 
     if use_W1:
-        assert X.shape == (num_tokens, K), (
-            f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
-        )
+        assert X.shape == (
+            num_tokens,
+            K,
+        ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
     else:
-        assert X.shape == (total_tokens, N), (
-            f"X.shape: {X.shape}, total_tokens: {total_tokens}, N: {N}"
-        )
+        assert X.shape == (
+            total_tokens,
+            N,
+        ), f"X.shape: {X.shape}, total_tokens: {total_tokens}, N: {N}"
 
     W_test = W.detach().clone().requires_grad_(True)
 
     topk_weights, topk_ids = calculate_topk(
-        gating_output, topk, use_sigmoid=use_sigmoid, renormalize=renormalize
+        gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize
     )
     topk_weights = topk_weights.view(-1)  # num_tokens * topk
     topk_ids = topk_ids.view(-1)  # num_tokens * topk
 
-    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts=E)
+    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)
     assert len(gather_indices) == total_tokens
     assert len(expert_token_counts) == num_experts
 
@@ -535,10 +539,10 @@ def _test_grouped_gemm_backward_dX(
     assert Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)
 
     output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
-    ref_output = torch_grouped_gemm(X=Xperm, W=W, m_sizes=expert_token_counts)
-    assert ref_output.shape == output_shape, (
-        f"ref_output.shape: {ref_output.shape}, output_shape: {output_shape}"
-    )
+    ref_output = torch_grouped_gemm(X = Xperm, W = W, m_sizes = expert_token_counts)
+    assert (
+        ref_output.shape == output_shape
+    ), f"ref_output.shape: {ref_output.shape}, output_shape: {output_shape}"
 
     if permute_y:
         ref_output = unpermute(ref_output, gather_indices)
@@ -566,14 +570,14 @@ def _test_grouped_gemm_backward_dX(
         if not autotune:
             kernel_config_fwd = KernelConfigForward()
             kernel_config_bwd_dX = KernelConfigBackward_dX(
-                use_tma_load_dy=use_tma_load_dy,
-                use_tma_load_w=use_tma_load_w,
-                use_tma_store=use_tma_store,
-                BLOCK_SIZE_M=BLOCK_SIZE_M,
-                BLOCK_SIZE_N=BLOCK_SIZE_N,
-                BLOCK_SIZE_K=BLOCK_SIZE_K,
-                num_warps=num_warps,
-                num_stages=num_stages,
+                use_tma_load_dy = use_tma_load_dy,
+                use_tma_load_w = use_tma_load_w,
+                use_tma_store = use_tma_store,
+                BLOCK_SIZE_M = BLOCK_SIZE_M,
+                BLOCK_SIZE_N = BLOCK_SIZE_N,
+                BLOCK_SIZE_K = BLOCK_SIZE_K,
+                num_warps = num_warps,
+                num_stages = num_stages,
             )
             kernel_config_bwd_dW = KernelConfigBackward_dW()
         else:
@@ -603,25 +607,25 @@ def _test_grouped_gemm_backward_dX(
             else Xperm.detach().clone().requires_grad_(True)
         )
         test_output = grouped_gemm(
-            X=X_,
-            W=W_test,
-            m_sizes=expert_token_counts,
-            gather_indices=gather_indices,
-            topk=topk,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            autotune=autotune,
-            kernel_config_fwd=kernel_config_fwd,
-            kernel_config_bwd_dX=kernel_config_bwd_dX,
-            is_first_gemm=use_W1,
-            dX_only=True,
-        )
-        assert test_output.shape == ref_output.shape, (
-            f"test_output.shape: {test_output.shape}, ref_output.shape: {ref_output.shape}"
-        )
-        assert torch.allclose(test_output, ref_output, atol=atol, rtol=rtol), (
-            f"Grouped gemm backward_dX forward outputs mismatch: {(test_output - ref_output).abs().max().item():.6f}"
+            X = X_,
+            W = W_test,
+            m_sizes = expert_token_counts,
+            gather_indices = gather_indices,
+            topk = topk,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            autotune = autotune,
+            kernel_config_fwd = kernel_config_fwd,
+            kernel_config_bwd_dX = kernel_config_bwd_dX,
+            is_first_gemm = use_W1,
+            dX_only = True,
         )
+        assert (
+            test_output.shape == ref_output.shape
+        ), f"test_output.shape: {test_output.shape}, ref_output.shape: {ref_output.shape}"
+        assert torch.allclose(
+            test_output, ref_output, atol = atol, rtol = rtol
+        ), f"Grouped gemm backward_dX forward outputs mismatch: {(test_output - ref_output).abs().max().item():.6f}"
         test_output.backward(grad_output)
         assert X_.grad is not None
 
@@ -631,40 +635,40 @@ def _test_grouped_gemm_backward_dX(
         # This is due to the fact that torch autograd handles unpermute and sum reduction differently see: https://discuss.pytorch.org/t/permute-unpermute-gradient/219557    else:
         if permute_x and use_W1:
             X_grad_unperm = unpermute(Xperm.grad, gather_indices)
-            manual_grad_check = X_grad_unperm.view(num_tokens, topk, K).sum(dim=1)
-            assert manual_grad_check.shape == X_.grad.shape, (
-                f"manual_grad_check.shape: {manual_grad_check.shape}, X_.grad.shape: {X_.grad.shape}"
-            )
-            assert torch.allclose(manual_grad_check, X_.grad, atol=atol, rtol=rtol), (
-                f"Grouped gemm backward_dX forward outputs mismatch: {(manual_grad_check - X_.grad).abs().max().item():.6f}"
-            )
+            manual_grad_check = X_grad_unperm.view(num_tokens, topk, K).sum(dim = 1)
+            assert (
+                manual_grad_check.shape == X_.grad.shape
+            ), f"manual_grad_check.shape: {manual_grad_check.shape}, X_.grad.shape: {X_.grad.shape}"
+            assert torch.allclose(
+                manual_grad_check, X_.grad, atol = atol, rtol = rtol
+            ), f"Grouped gemm backward_dX forward outputs mismatch: {(manual_grad_check - X_.grad).abs().max().item():.6f}"
             manual_diff = (X_.grad - manual_grad_check).abs().max().item()
             autograd_diff = (X_.grad - X.grad).abs().max().item()
             print(f"manual_diff: {manual_diff:.6f}, autograd_diff: {autograd_diff:.6f}")
         else:
-            assert torch.allclose(X_.grad, ref_grad, atol=atol, rtol=rtol), (
-                f"Grouped gemm backward_dX forward outputs mismatch: {(X_.grad - ref_grad).abs().max().item():.6f}"
-            )
+            assert torch.allclose(
+                X_.grad, ref_grad, atol = atol, rtol = rtol
+            ), f"Grouped gemm backward_dX forward outputs mismatch: {(X_.grad - ref_grad).abs().max().item():.6f}"
         return
     else:
         dX_test = grouped_gemm_dX(
-            dY=grad_output,
-            W=W_test,
-            gather_indices=gather_indices,
-            m_sizes=expert_token_counts,
-            topk=topk,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            use_tma_load_w=use_tma_load_w,
-            use_tma_load_dy=use_tma_load_dy,
-            use_tma_store=use_tma_store,
-            autotune=autotune,
-            BLOCK_SIZE_M=BLOCK_SIZE_M,
-            BLOCK_SIZE_N=BLOCK_SIZE_N,
-            BLOCK_SIZE_K=BLOCK_SIZE_K,
-            num_warps=num_warps,
-            num_stages=num_stages,
-            flatten=flatten,
+            dY = grad_output,
+            W = W_test,
+            gather_indices = gather_indices,
+            m_sizes = expert_token_counts,
+            topk = topk,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            use_tma_load_w = use_tma_load_w,
+            use_tma_load_dy = use_tma_load_dy,
+            use_tma_store = use_tma_store,
+            autotune = autotune,
+            BLOCK_SIZE_M = BLOCK_SIZE_M,
+            BLOCK_SIZE_N = BLOCK_SIZE_N,
+            BLOCK_SIZE_K = BLOCK_SIZE_K,
+            num_warps = num_warps,
+            num_stages = num_stages,
+            flatten = flatten,
             # debug=True,
         )
 
@@ -673,21 +677,21 @@ def _test_grouped_gemm_backward_dX(
     if permute_x and use_W1:
         ref_grad = unpermute(ref_grad, gather_indices)
 
-    assert ref_grad.shape == dX_test.shape, (
-        f"Grouped gemm manual backward_dX outputs mismatch: ref_grad: {ref_grad.shape}, dX_test: {dX_test.shape}"
-    )
+    assert (
+        ref_grad.shape == dX_test.shape
+    ), f"Grouped gemm manual backward_dX outputs mismatch: ref_grad: {ref_grad.shape}, dX_test: {dX_test.shape}"
     diff = (ref_grad - dX_test).abs().max().item()
 
-    assert torch.allclose(ref_grad, dX_test, atol=atol, rtol=rtol), (
-        f"Grouped gemm manual backward_dX outputs mismatch: {diff:.6f}"
-    )
+    assert torch.allclose(
+        ref_grad, dX_test, atol = atol, rtol = rtol
+    ), f"Grouped gemm manual backward_dX outputs mismatch: {diff:.6f}"
 
     if permute_x and use_W1:
         # Show that reduction results in diffs
         # First calculate X.grad manually by backpropping through unpermuted ref_grad
-        dX_ref_check = ref_grad.view(num_tokens, topk, K).sum(dim=1)
+        dX_ref_check = ref_grad.view(num_tokens, topk, K).sum(dim = 1)
         # Do the same for the actual output of the kernel
-        dX_test_check = dX_test.view(num_tokens, topk, K).sum(dim=1)
+        dX_test_check = dX_test.view(num_tokens, topk, K).sum(dim = 1)
         # Show diffs for each combination
         diff_ref_check = (X.grad - dX_ref_check).abs().max().item()
         diff_test_check = (X.grad - dX_test_check).abs().max().item()
@@ -702,17 +706,17 @@ def _test_grouped_gemm_backward_dX(
 @pytest.mark.parametrize(
     "kernel_config",
     KERNEL_CONFIGS_BWD_dX,
-    ids=lambda x: x.to_string(include_tuning_params=True, include_tma=True),
+    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
 )
 @pytest.mark.parametrize(
     "model_config",
     SMALL_MODEL_CONFIGS[:1] + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dX_manual(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -720,10 +724,10 @@ def test_grouped_gemm_backward_dX_manual(
     use_W1: bool,
 ):
     _test_grouped_gemm_backward_dX(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
-        use_autograd=False,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
+        use_autograd = False,
         **asdict(kernel_config),
     )
 
@@ -731,17 +735,17 @@ def test_grouped_gemm_backward_dX_manual(
 @pytest.mark.parametrize(
     "kernel_config",
     KERNEL_CONFIGS_BWD_dX,
-    ids=lambda x: x.to_string(include_tuning_params=True, include_tma=True),
+    ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
 )
 @pytest.mark.parametrize(
     "model_config",
     SMALL_MODEL_CONFIGS[:1] + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dX_manual_autograd(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -749,32 +753,32 @@ def test_grouped_gemm_backward_dX_manual_autograd(
     use_W1: bool,
 ):
     _test_grouped_gemm_backward_dX(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
-        use_autograd=True,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
+        use_autograd = True,
         **asdict(kernel_config),
     )
 
 
 @pytest.mark.parametrize(
-    "num_autotune_configs", [20], ids=lambda x: f"num_autotune_configs={x}"
+    "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
 )
 @pytest.mark.parametrize(
-    "permute_x", [True, False], ids=lambda x: "permute_x" if x else ""
+    "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
 )
 @pytest.mark.parametrize(
-    "permute_y", [True, False], ids=lambda x: "permute_y" if x else ""
+    "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
 )
 @pytest.mark.parametrize(
     "model_config",
     [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dX_autotune(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -785,35 +789,35 @@ def test_grouped_gemm_backward_dX_autotune(
 ):
     # TMA loads / stores will be autotuned
     _test_grouped_gemm_backward_dX(
-        data_config=data_config,
-        model_config=model_config,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        use_W1=use_W1,
-        autotune=True,
-        use_autograd=False,
-        num_autotune_configs=num_autotune_configs,
+        data_config = data_config,
+        model_config = model_config,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        use_W1 = use_W1,
+        autotune = True,
+        use_autograd = False,
+        num_autotune_configs = num_autotune_configs,
     )
 
 
 @pytest.mark.parametrize(
-    "num_autotune_configs", [20], ids=lambda x: f"num_autotune_configs={x}"
+    "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
 )
 @pytest.mark.parametrize(
-    "permute_x", [True, False], ids=lambda x: "permute_x" if x else ""
+    "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
 )
 @pytest.mark.parametrize(
-    "permute_y", [True, False], ids=lambda x: "permute_y" if x else ""
+    "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
 )
 @pytest.mark.parametrize(
     "model_config",
     [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dX_autotune_autograd(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -824,14 +828,14 @@ def test_grouped_gemm_backward_dX_autotune_autograd(
 ):
     # TMA loads / stores will be autotuned
     _test_grouped_gemm_backward_dX(
-        data_config=data_config,
-        model_config=model_config,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        use_W1=use_W1,
-        autotune=True,
-        use_autograd=True,
-        num_autotune_configs=num_autotune_configs,
+        data_config = data_config,
+        model_config = model_config,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        use_W1 = use_W1,
+        autotune = True,
+        use_autograd = True,
+        num_autotune_configs = num_autotune_configs,
     )
 
 
@@ -860,25 +864,25 @@ def _test_grouped_gemm_backward_dW(
     if not check_valid_config(
         permute_x,
         permute_y,
-        fuse_mul_post=fuse_mul_post,
-        use_W1=use_W1,
-        is_backward=True,
+        fuse_mul_post = fuse_mul_post,
+        use_W1 = use_W1,
+        is_backward = True,
     ):
         pytest.skip(
-            f"Skipping test due to invalid config: {permute_x=} {permute_y=} {use_W1=}"
+            f"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = }"
         )
 
     if use_tma_store and not allow_tma_store:
         pytest.skip("TMA store needs to be debugged due to non-deterministic behavior")
 
     X1, X2, W1, W2, gating_output = make_inputs(
-        M=data_config.bs * data_config.seq_len,
-        N=model_config.intermediate_size,
-        K=model_config.hidden_size,
-        E=model_config.num_experts,
-        topk=model_config.topk,
-        dtype=data_config.dtype,
-        requires_grad=True,
+        M = data_config.bs * data_config.seq_len,
+        N = model_config.intermediate_size,
+        K = model_config.hidden_size,
+        E = model_config.num_experts,
+        topk = model_config.topk,
+        dtype = data_config.dtype,
+        requires_grad = True,
     )
     topk = model_config.topk
     num_experts = model_config.num_experts
@@ -892,13 +896,15 @@ def _test_grouped_gemm_backward_dW(
     W = W1 if use_W1 else W2
 
     if use_W1:
-        assert X.shape == (num_tokens, K), (
-            f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
-        )
+        assert X.shape == (
+            num_tokens,
+            K,
+        ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
     else:
-        assert X.shape == (num_tokens * topk, N), (
-            f"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}"
-        )
+        assert X.shape == (
+            num_tokens * topk,
+            N,
+        ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}"
 
     total_tokens = num_tokens * topk
     output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
@@ -907,12 +913,12 @@ def _test_grouped_gemm_backward_dW(
     W_test = W.detach().clone().requires_grad_(True)
 
     topk_weights, topk_ids = calculate_topk(
-        gating_output, topk, use_sigmoid=use_sigmoid, renormalize=renormalize
+        gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize
     )
     topk_weights = topk_weights.view(-1)  # num_tokens * topk
     topk_ids = topk_ids.view(-1)  # num_tokens * topk
 
-    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts=E)
+    expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)
     assert len(gather_indices) == total_tokens
     assert len(expert_token_counts) == num_experts
 
@@ -928,7 +934,7 @@ def _test_grouped_gemm_backward_dW(
 
     output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
 
-    ref_output = torch_grouped_gemm(X=Xperm, W=W, m_sizes=expert_token_counts)
+    ref_output = torch_grouped_gemm(X = Xperm, W = W, m_sizes = expert_token_counts)
     assert ref_output.shape == output_shape
 
     # if permute_y then the assumption is that the output of grouped_gemm was unpermuted on store
@@ -945,7 +951,7 @@ def _test_grouped_gemm_backward_dW(
     X_ = X_test if permute_x else Xperm_test
 
     if debug:
-        torch.set_printoptions(precision=4)
+        torch.set_printoptions(precision = 4)
         for i in range(num_experts):
             print(f"Expert {i} weight grad:\n{W.grad[i, :5, :5]}")
 
@@ -963,24 +969,24 @@ def _test_grouped_gemm_backward_dW(
         if not autotune:
             kernel_config_fwd = KernelConfigForward(
                 # Only care about backward_dW config
-                use_tma_load_w=False,
-                use_tma_load_x=False,
-                use_tma_store=False,
-                BLOCK_SIZE_M=BLOCK_SIZE_M,
-                BLOCK_SIZE_N=BLOCK_SIZE_N,
-                BLOCK_SIZE_K=BLOCK_SIZE_K,
-                num_warps=num_warps,
-                num_stages=num_stages,
+                use_tma_load_w = False,
+                use_tma_load_x = False,
+                use_tma_store = False,
+                BLOCK_SIZE_M = BLOCK_SIZE_M,
+                BLOCK_SIZE_N = BLOCK_SIZE_N,
+                BLOCK_SIZE_K = BLOCK_SIZE_K,
+                num_warps = num_warps,
+                num_stages = num_stages,
             )
             kernel_config_bwd_dW = KernelConfigBackward_dW(
-                use_tma_load_dy=use_tma_load_dy,
-                use_tma_load_x=use_tma_load_x,
-                use_tma_store=use_tma_store,
-                BLOCK_SIZE_M=BLOCK_SIZE_M,
-                BLOCK_SIZE_N=BLOCK_SIZE_N,
-                BLOCK_SIZE_K=BLOCK_SIZE_K,
-                num_warps=num_warps,
-                num_stages=num_stages,
+                use_tma_load_dy = use_tma_load_dy,
+                use_tma_load_x = use_tma_load_x,
+                use_tma_store = use_tma_store,
+                BLOCK_SIZE_M = BLOCK_SIZE_M,
+                BLOCK_SIZE_N = BLOCK_SIZE_N,
+                BLOCK_SIZE_K = BLOCK_SIZE_K,
+                num_warps = num_warps,
+                num_stages = num_stages,
             )
         else:
             from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel
@@ -1001,56 +1007,56 @@ def _test_grouped_gemm_backward_dW(
             kernel_config_bwd_dW = None
 
         test_output = grouped_gemm(
-            X=X_,
-            W=W_test,
-            m_sizes=expert_token_counts,
-            gather_indices=gather_indices,
-            topk=topk,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            kernel_config_fwd=kernel_config_fwd,
-            kernel_config_bwd_dW=kernel_config_bwd_dW,
-            autotune=autotune,
-            is_first_gemm=use_W1,
-            dW_only=True,
-        )
-        assert test_output.shape == ref_output.shape, (
-            f"Grouped gemm autograd backward_dW outputs mismatch: {test_output.shape} != {ref_output.shape}"
-        )
-        assert torch.allclose(test_output, ref_output, atol=atol, rtol=rtol), (
-            f"Grouped gemm autograd backward_dW forward outputs mismatch: {test_output.shape} != {ref_output.shape}"
+            X = X_,
+            W = W_test,
+            m_sizes = expert_token_counts,
+            gather_indices = gather_indices,
+            topk = topk,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            kernel_config_fwd = kernel_config_fwd,
+            kernel_config_bwd_dW = kernel_config_bwd_dW,
+            autotune = autotune,
+            is_first_gemm = use_W1,
+            dW_only = True,
         )
+        assert (
+            test_output.shape == ref_output.shape
+        ), f"Grouped gemm autograd backward_dW outputs mismatch: {test_output.shape} != {ref_output.shape}"
+        assert torch.allclose(
+            test_output, ref_output, atol = atol, rtol = rtol
+        ), f"Grouped gemm autograd backward_dW forward outputs mismatch: {test_output.shape} != {ref_output.shape}"
         test_output.backward(grad_output)
         assert W_test.grad is not None
         dW_test = W_test.grad
     else:
         dW_test = grouped_gemm_dW(
-            dY=grad_output,
-            X=X_,
-            m_sizes=expert_token_counts,
-            gather_indices=gather_indices,
-            topk=topk,
-            permute_x=permute_x,
-            permute_y=permute_y,
-            use_tma_load_dy=use_tma_load_dy,
-            use_tma_load_x=use_tma_load_x,
-            use_tma_store=use_tma_store,
-            BLOCK_SIZE_M=BLOCK_SIZE_M,
-            BLOCK_SIZE_N=BLOCK_SIZE_N,
-            BLOCK_SIZE_K=BLOCK_SIZE_K,
-            num_warps=num_warps,
-            num_stages=num_stages,
-            flatten=flatten,
-            autotune=autotune,
-            debug=debug,
+            dY = grad_output,
+            X = X_,
+            m_sizes = expert_token_counts,
+            gather_indices = gather_indices,
+            topk = topk,
+            permute_x = permute_x,
+            permute_y = permute_y,
+            use_tma_load_dy = use_tma_load_dy,
+            use_tma_load_x = use_tma_load_x,
+            use_tma_store = use_tma_store,
+            BLOCK_SIZE_M = BLOCK_SIZE_M,
+            BLOCK_SIZE_N = BLOCK_SIZE_N,
+            BLOCK_SIZE_K = BLOCK_SIZE_K,
+            num_warps = num_warps,
+            num_stages = num_stages,
+            flatten = flatten,
+            autotune = autotune,
+            debug = debug,
         )
-    assert W.grad.shape == dW_test.shape, (
-        f"Grouped gemm manual backward_dW outputs mismatch: W.grad: {W.grad.shape}, dW_test: {dW_test.shape}"
-    )
+    assert (
+        W.grad.shape == dW_test.shape
+    ), f"Grouped gemm manual backward_dW outputs mismatch: W.grad: {W.grad.shape}, dW_test: {dW_test.shape}"
 
     if debug:
         with torch.no_grad():
-            if not torch.allclose(W.grad, dW_test, atol=atol, rtol=rtol):
+            if not torch.allclose(W.grad, dW_test, atol = atol, rtol = rtol):
                 print(f"Ref Wgrad sum: {W.grad.sum().item():.4f}")
             print(f"Test Wgrad sum: {dW_test.sum().item():.4f}")
 
@@ -1061,30 +1067,30 @@ def _test_grouped_gemm_backward_dW(
                 print(f"Expert {i} diff: {expert_diff:.6f}")
 
             diff = (W.grad - dW_test).abs().max().item()
-            assert False, (
-                f"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}"
-            )
+            assert (
+                False
+            ), f"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}"
     else:
         diff = (W.grad - dW_test).abs().max().item()
-        assert torch.allclose(W.grad, dW_test, atol=atol, rtol=rtol), (
-            f"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}"
-        )
+        assert torch.allclose(
+            W.grad, dW_test, atol = atol, rtol = rtol
+        ), f"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}"
 
 
 @pytest.mark.parametrize(
     "kernel_config",
     KERNEL_CONFIGS_BWD_dW,
-    ids=lambda x: x.to_string(include_tuning_params=False, include_tma=True),
+    ids = lambda x: x.to_string(include_tuning_params = False, include_tma = True),
 )
 @pytest.mark.parametrize(
     "model_config",
     SMALL_MODEL_CONFIGS + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dW_manual(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -1093,10 +1099,10 @@ def test_grouped_gemm_backward_dW_manual(
     debug: bool = False,
 ):
     _test_grouped_gemm_backward_dW(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
-        use_autograd=False,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
+        use_autograd = False,
         **asdict(kernel_config),
     )
 
@@ -1104,17 +1110,17 @@ def test_grouped_gemm_backward_dW_manual(
 @pytest.mark.parametrize(
     "kernel_config",
     KERNEL_CONFIGS_BWD_dW,
-    ids=lambda x: x.to_string(include_tuning_params=False, include_tma=True),
+    ids = lambda x: x.to_string(include_tuning_params = False, include_tma = True),
 )
 @pytest.mark.parametrize(
     "model_config",
     SMALL_MODEL_CONFIGS + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dW_manual_autograd(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -1123,32 +1129,32 @@ def test_grouped_gemm_backward_dW_manual_autograd(
     debug: bool = False,
 ):
     _test_grouped_gemm_backward_dW(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
-        use_autograd=True,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
+        use_autograd = True,
         **asdict(kernel_config),
     )
 
 
 @pytest.mark.parametrize(
-    "num_autotune_configs", [20], ids=lambda x: f"num_autotune_configs={x}"
+    "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
 )
 @pytest.mark.parametrize(
-    "permute_x", [True, False], ids=lambda x: "permute_x" if x else ""
+    "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
 )
 @pytest.mark.parametrize(
-    "permute_y", [True, False], ids=lambda x: "permute_y" if x else ""
+    "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
 )
 @pytest.mark.parametrize(
     "model_config",
     [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dW_autotune(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -1158,35 +1164,35 @@ def test_grouped_gemm_backward_dW_autotune(
     num_autotune_configs: int,
 ):
     _test_grouped_gemm_backward_dW(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        autotune=True,
-        use_autograd=False,
-        num_autotune_configs=num_autotune_configs,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        autotune = True,
+        use_autograd = False,
+        num_autotune_configs = num_autotune_configs,
     )
 
 
 @pytest.mark.parametrize(
-    "num_autotune_configs", [20], ids=lambda x: f"num_autotune_configs={x}"
+    "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
 )
 @pytest.mark.parametrize(
-    "permute_x", [True, False], ids=lambda x: "permute_x" if x else ""
+    "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
 )
 @pytest.mark.parametrize(
-    "permute_y", [True, False], ids=lambda x: "permute_y" if x else ""
+    "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
 )
 @pytest.mark.parametrize(
     "model_config",
     [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
-    ids=lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
+    ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
 )
 @pytest.mark.parametrize(
-    "data_config", DATA_CONFIGS, ids=lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
+    "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
 )
-@pytest.mark.parametrize("use_W1", [True, False], ids=lambda x: f"use_W1={x}")
+@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
 def test_grouped_gemm_backward_dW_autotune_autograd(
     data_config: DataConfig,
     model_config: ModelConfig,
@@ -1196,12 +1202,12 @@ def test_grouped_gemm_backward_dW_autotune_autograd(
     num_autotune_configs: int,
 ):
     _test_grouped_gemm_backward_dW(
-        data_config=data_config,
-        model_config=model_config,
-        use_W1=use_W1,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        autotune=True,
-        use_autograd=True,
-        num_autotune_configs=num_autotune_configs,
+        data_config = data_config,
+        model_config = model_config,
+        use_W1 = use_W1,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        autotune = True,
+        use_autograd = True,
+        num_autotune_configs = num_autotune_configs,
     )
diff --git a/unsloth/kernels/moe/tests/test_llama4_moe.py b/unsloth/kernels/moe/tests/test_llama4_moe.py
index 27d5d99a2..13ad552bf 100644
--- a/unsloth/kernels/moe/tests/test_llama4_moe.py
+++ b/unsloth/kernels/moe/tests/test_llama4_moe.py
@@ -37,7 +37,7 @@
 
 
 @contextmanager
-def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
+def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
     print(char * num_chars)
     print(prelude)
     yield
@@ -81,7 +81,7 @@ def prep_triton_kernel_traits(autotune):
 
 
 def sparse_to_dense(t: torch.Tensor):
-    t = t.sum(dim=0).view(-1)
+    t = t.sum(dim = 0).view(-1)
     return t
 
 
@@ -91,9 +91,9 @@ def _check_diff(
     t2: torch.Tensor,
     atol,
     rtol,
-    precision=".6f",
-    verbose=False,
-    msg="",
+    precision = ".6f",
+    verbose = False,
+    msg = "",
 ):
     t2 = t2.view_as(t1)
     diff = t1.sub(t2).abs().max().item()
@@ -101,7 +101,7 @@ def _check_diff(
         if msg == "":
             msg = "diff"
         print(f"{msg}: {diff:{precision}}")
-    assert torch.allclose(t1, t2, atol=atol, rtol=rtol)
+    assert torch.allclose(t1, t2, atol = atol, rtol = rtol)
 
 
 def run_backwards(y: torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module):
@@ -115,19 +115,19 @@ def _check_grads(
     m2: torch.nn.Module,
     atol,
     rtol,
-    precision=".6f",
-    verbose=False,
-    msg="",
+    precision = ".6f",
+    verbose = False,
+    msg = "",
 ):
     for name, param in m1.named_parameters():
         _check_diff(
             param.grad,
             m2.get_parameter(name).grad,
-            atol=atol,
-            rtol=rtol,
-            precision=precision,
-            verbose=verbose,
-            msg=f"{msg}:{name}.grad",
+            atol = atol,
+            rtol = rtol,
+            precision = precision,
+            verbose = verbose,
+            msg = f"{msg}:{name}.grad",
         )
 
 
@@ -139,19 +139,19 @@ def model_config():
 @pytest.mark.parametrize(
     "overlap_router_shared",
     [False, True],
-    ids=lambda x: "overlap_router_shared" if x else "no_overlap",
+    ids = lambda x: "overlap_router_shared" if x else "no_overlap",
 )
 @pytest.mark.parametrize(
-    "permute_y", [False, True], ids=lambda x: "permute_y" if x else "no_permute_y"
+    "permute_y", [False, True], ids = lambda x: "permute_y" if x else "no_permute_y"
 )
 @pytest.mark.parametrize(
-    "permute_x", [False], ids=lambda x: "permute_x" if x else "no_permute_x"
+    "permute_x", [False], ids = lambda x: "permute_x" if x else "no_permute_x"
 )  # Llama4 does not support permute_x
 @pytest.mark.parametrize(
-    "autotune", [True], ids=lambda x: "autotune" if x else "manual"
+    "autotune", [True], ids = lambda x: "autotune" if x else "manual"
 )
-@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}")
-@pytest.mark.parametrize("dtype", DTYPES, ids=str)
+@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
+@pytest.mark.parametrize("dtype", DTYPES, ids = str)
 def test_llama4_ref(
     dtype: torch.dtype,
     seqlen,
@@ -161,9 +161,9 @@ def test_llama4_ref(
     overlap_router_shared: bool,
     model_config: Llama4TextConfig,  # test fixture
     bs: int = 1,
-    device="cuda",
-    precision=".6f",
-    verbose=False,
+    device = "cuda",
+    precision = ".6f",
+    verbose = False,
 ):
     torch.manual_seed(
         SEED
@@ -172,24 +172,24 @@ def test_llama4_ref(
     hidden_dim = model_config.hidden_size
     atol, rtol = TOLERANCES[dtype]
     check_diff = partial(
-        _check_diff, atol=atol, rtol=rtol, precision=precision, verbose=verbose
+        _check_diff, atol = atol, rtol = rtol, precision = precision, verbose = verbose
     )
     check_grads = partial(
-        _check_grads, atol=atol, rtol=rtol, precision=precision, verbose=verbose
+        _check_grads, atol = atol, rtol = rtol, precision = precision, verbose = verbose
     )
 
     # Reference op -- HF
-    llama4_ref = Llama4TextMoe(model_config).to(dtype=dtype, device=device)
+    llama4_ref = Llama4TextMoe(model_config).to(dtype = dtype, device = device)
 
     # Torch grouped gemm impl
     llama4_gg_ref = Llama4GroupedGemmTextMoe(
-        model_config, overlap_router_shared=overlap_router_shared
-    ).to(dtype=dtype, device=device)
+        model_config, overlap_router_shared = overlap_router_shared
+    ).to(dtype = dtype, device = device)
     llama4_gg_ref.copy_weights(llama4_ref)
     llama4_gg_ref.check_weights(llama4_ref)
 
     x_ref = torch.randn(
-        bs, seqlen, hidden_dim, dtype=dtype, device=device, requires_grad=True
+        bs, seqlen, hidden_dim, dtype = dtype, device = device, requires_grad = True
     )
     x_torch_gg = x_ref.detach().clone().requires_grad_()
     x_triton = x_ref.detach().clone().requires_grad_()
@@ -198,9 +198,9 @@ def test_llama4_ref(
     y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg)
     assert y_ref.shape == y_torch_gg.shape, f"{y_ref.shape} != {y_torch_gg.shape}"
     with annotated_context("Testing torch grouped gemm Llama4TextMoe"):
-        check_diff(y_ref, y_torch_gg, msg="y_torch_gg")
+        check_diff(y_ref, y_torch_gg, msg = "y_torch_gg")
         check_diff(
-            sparse_to_dense(routing_ref), routing_torch_gg, msg="routing_torch_gg"
+            sparse_to_dense(routing_ref), routing_torch_gg, msg = "routing_torch_gg"
         )
 
     kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = (
@@ -209,38 +209,38 @@ def test_llama4_ref(
 
     llama4_triton = Llama4TritonTextMoe(
         model_config,
-        overlap_router_shared=overlap_router_shared,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        autotune=autotune,
-        kernel_config_fwd=kernel_config_fwd,
-        kernel_config_bwd_dW=kernel_config_bwd_dW,
-        kernel_config_bwd_dX=kernel_config_bwd_dX,
-    ).to(device=device, dtype=dtype)
+        overlap_router_shared = overlap_router_shared,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        autotune = autotune,
+        kernel_config_fwd = kernel_config_fwd,
+        kernel_config_bwd_dW = kernel_config_bwd_dW,
+        kernel_config_bwd_dX = kernel_config_bwd_dX,
+    ).to(device = device, dtype = dtype)
     llama4_triton.copy_weights(llama4_ref)
     llama4_triton.check_weights(llama4_ref)
 
     y_triton, routing_triton = llama4_triton(x_triton)
     with annotated_context("Testing triton grouped gemm Llama4TextMoe forward"):
-        check_diff(y_ref, y_triton, msg="y_triton")
-        check_diff(sparse_to_dense(routing_ref), routing_triton, msg="routing_triton")
+        check_diff(y_ref, y_triton, msg = "y_triton")
+        check_diff(sparse_to_dense(routing_ref), routing_triton, msg = "routing_triton")
 
     ref_grad = torch.randn_like(y_ref)
     run_backwards(y_ref, ref_grad, llama4_ref)
     run_backwards(y_torch_gg, ref_grad, llama4_gg_ref)
     with annotated_context("Testing torch group gemm Llama4TextMoe backward"):
-        check_grads(llama4_ref, llama4_gg_ref, msg="torch_gg")
+        check_grads(llama4_ref, llama4_gg_ref, msg = "torch_gg")
 
     run_backwards(y_triton, ref_grad, llama4_triton)
     with annotated_context("Testing triton group gemm Llama4TextMoe backward"):
-        check_grads(llama4_ref, llama4_triton, msg="triton")
+        check_grads(llama4_ref, llama4_triton, msg = "triton")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--seqlen", type=int, default=1024)
+    parser.add_argument("--seqlen", type = int, default = 1024)
     parser.add_argument(
-        "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
+        "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
     )
     args = parser.parse_args()
     args.dtype = getattr(torch, args.dtype)
@@ -251,12 +251,12 @@ def test_llama4_ref(
     text_config: Llama4TextConfig = get_text_config(model_id)
     for overlap in [False, True]:
         test_llama4_ref(
-            seqlen=args.seqlen,
-            model_config=text_config,
-            dtype=args.dtype,
-            autotune=True,
-            permute_x=False,
-            permute_y=True,
-            overlap_router_shared=overlap,
-            verbose=True,
+            seqlen = args.seqlen,
+            model_config = text_config,
+            dtype = args.dtype,
+            autotune = True,
+            permute_x = False,
+            permute_y = True,
+            overlap_router_shared = overlap,
+            verbose = True,
         )
diff --git a/unsloth/kernels/moe/tests/test_qwen3_moe.py b/unsloth/kernels/moe/tests/test_qwen3_moe.py
index 0e1eee990..42a8356c0 100644
--- a/unsloth/kernels/moe/tests/test_qwen3_moe.py
+++ b/unsloth/kernels/moe/tests/test_qwen3_moe.py
@@ -68,18 +68,18 @@
 }
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope = "module")
 def model_id():
     return "Qwen/Qwen3-30B-A3B"
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope = "module")
 def config(model_id: str):
     return AutoConfig.from_pretrained(model_id)
 
 
 @contextmanager
-def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
+def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
     print(char * num_chars)
     print(prelude)
     yield
@@ -96,16 +96,16 @@ def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
 
 
 @pytest.mark.parametrize(
-    "permute_y", [True], ids=lambda x: "permute_y" if x else "no_permute_y"
+    "permute_y", [True], ids = lambda x: "permute_y" if x else "no_permute_y"
 )
 @pytest.mark.parametrize(
-    "permute_x", [True], ids=lambda x: "permute_x" if x else "no_permute_x"
+    "permute_x", [True], ids = lambda x: "permute_x" if x else "no_permute_x"
 )
 @pytest.mark.parametrize(
-    "autotune", [True], ids=lambda x: "autotune" if x else "manual"
+    "autotune", [True], ids = lambda x: "autotune" if x else "manual"
 )
-@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}")
-@pytest.mark.parametrize("dtype", DTYPES, ids=str)
+@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
+@pytest.mark.parametrize("dtype", DTYPES, ids = str)
 def test_qwen3_moe(
     config: Qwen3MoeConfig,
     seqlen: int,
@@ -157,36 +157,36 @@ def test_qwen3_moe(
     # Triton kernel grouped gemm version of MoE Block -- this is what we're testing
     fused_gemm_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
         moe_block,
-        permute_x=permute_x,
-        permute_y=permute_y,
-        autotune=autotune,
-        kernel_config_fwd=kernel_config_fwd,
-        kernel_config_bwd_dW=kernel_config_bwd_dW,
-        kernel_config_bwd_dX=kernel_config_bwd_dX,
+        permute_x = permute_x,
+        permute_y = permute_y,
+        autotune = autotune,
+        kernel_config_fwd = kernel_config_fwd,
+        kernel_config_bwd_dW = kernel_config_bwd_dW,
+        kernel_config_bwd_dX = kernel_config_bwd_dX,
     ).to(device, dtype)
     fused_gemm_block.check_weights(moe_block)
 
     X = torch.randn(
-        bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
+        bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
     )
 
     # Forward
-    ref_result = run_forward(moe_block, X, is_grouped_gemm=False)
-    grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm=True)
-    fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm=True)
+    ref_result = run_forward(moe_block, X, is_grouped_gemm = False)
+    grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm = True)
+    fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm = True)
 
     with annotated_context(
         "Testing forward pass",
-        epilogue="Passed forward tests!",
-        char="=",
-        num_chars=100,
+        epilogue = "Passed forward tests!",
+        char = "=",
+        num_chars = 100,
     ):
         # Sanity checks
 
         with annotated_context(
             "Checking HF vs torch grouped gemm MoE forward outputs..."
         ):
-            check_fwd(ref_result, grouped_result, atol, rtol, verbose=False)
+            check_fwd(ref_result, grouped_result, atol, rtol, verbose = False)
 
         with annotated_context(
             "Checking torch grouped gemm MoE vs fused grouped gemm MoE forward outputs..."
@@ -195,42 +195,42 @@ def test_qwen3_moe(
             check_grouped_gemm_results(
                 grouped_result.grouped_gemm_result,
                 fused_result.grouped_gemm_result,
-                permute_y=permute_y,
-                atol=atol,
-                rtol=rtol,
-                verbose=False,
+                permute_y = permute_y,
+                atol = atol,
+                rtol = rtol,
+                verbose = False,
             )
         # Actual test
         with annotated_context(
             "Checking HF vs fused grouped gemm MoE forward outputs..."
         ):
-            check_fwd(ref_result, fused_result, atol, rtol, verbose=True)
+            check_fwd(ref_result, fused_result, atol, rtol, verbose = True)
 
     # Backward
     grad_output = torch.randn_like(ref_result.output)
     ref_backward_result = run_backward(
-        moe_block, grad_output, output=ref_result.output, X=ref_result.X
+        moe_block, grad_output, output = ref_result.output, X = ref_result.X
     )
     grouped_backward_result = run_backward(
         grouped_gemm_block,
         grad_output,
-        output=grouped_result.output,
-        X=grouped_result.X,
+        output = grouped_result.output,
+        X = grouped_result.X,
     )
     fused_backward_result = run_backward(
-        fused_gemm_block, grad_output, output=fused_result.output, X=fused_result.X
+        fused_gemm_block, grad_output, output = fused_result.output, X = fused_result.X
     )
 
     with annotated_context(
         "Testing backward pass",
-        epilogue="Passed backward tests!",
-        char="=",
-        num_chars=100,
+        epilogue = "Passed backward tests!",
+        char = "=",
+        num_chars = 100,
     ):
         # Sanity checks
         with annotated_context("Checking HF vs torch grouped gemm MoE grads..."):
             check_grads(
-                ref_backward_result, grouped_backward_result, atol, rtol, verbose=False
+                ref_backward_result, grouped_backward_result, atol, rtol, verbose = False
             )
         with annotated_context(
             "Checking torch grouped gemm MoE vs fused grouped gemm MoE grads..."
@@ -240,25 +240,25 @@ def test_qwen3_moe(
                 fused_backward_result,
                 atol,
                 rtol,
-                verbose=False,
+                verbose = False,
             )
 
         # Actual test
         with annotated_context("Checking HF vs fused grouped gemm MoE grads..."):
             check_grads(
-                ref_backward_result, fused_backward_result, atol, rtol, verbose=True
+                ref_backward_result, fused_backward_result, atol, rtol, verbose = True
             )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--seqlen", type=int, default=1024)
+    parser.add_argument("--seqlen", type = int, default = 1024)
     parser.add_argument(
-        "--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
+        "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
     )
-    parser.add_argument("--permute_x", action="store_true")
-    parser.add_argument("--permute_y", action="store_true")
-    parser.add_argument("--autotune", action="store_true")
+    parser.add_argument("--permute_x", action = "store_true")
+    parser.add_argument("--permute_y", action = "store_true")
+    parser.add_argument("--autotune", action = "store_true")
     args = parser.parse_args()
     args.dtype = getattr(torch, args.dtype)
     args_dict = vars(args)
diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py
index ec45c6033..accbed11b 100644
--- a/unsloth/kernels/rms_layernorm.py
+++ b/unsloth/kernels/rms_layernorm.py
@@ -17,20 +17,25 @@
 import torch
 from .utils import calculate_settings, torch_gpu_device
 
+
 @triton.jit
 def _rms_layernorm_forward(
-    Y, Y_row_stride : tl.constexpr,
-    X, X_row_stride : tl.constexpr,
-    W, W_row_stride : tl.constexpr,
-    r, r_row_stride : tl.constexpr,
-    n_cols     : tl.constexpr,
-    eps        : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    Y,
+    Y_row_stride: tl.constexpr,
+    X,
+    X_row_stride: tl.constexpr,
+    W,
+    W_row_stride: tl.constexpr,
+    r,
+    r_row_stride: tl.constexpr,
+    n_cols: tl.constexpr,
+    eps: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     """
-        Fast RMS Layernorm kernel
-        Inspiration from a Triton tutorial:
-        https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    Fast RMS Layernorm kernel
+    Inspiration from a Triton tutorial:
+    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
     """
     row_idx = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -41,61 +46,70 @@ def _rms_layernorm_forward(
     r += row_idx * r_row_stride
 
     X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
-    W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
+    W_row = tl.load(W + col_offsets, mask = mask, other = 0)  # .to(tl.float32)
 
     row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
     inv_var = tl.math.rsqrt(row_var + eps)
     tl.store(r, inv_var)
     normed = X_row * inv_var
-    normed = normed.to(W_row.dtype) # Exact copy from HF
+    normed = normed.to(W_row.dtype)  # Exact copy from HF
     output = normed * W_row
     tl.store(Y + col_offsets, output, mask = mask)
-pass
 
 
 def _rms_layernorm_backward(
-    dY, dY_row_stride : tl.constexpr,
-    dX, dX_row_stride : tl.constexpr,
-    X,   X_row_stride : tl.constexpr,
-    W,   W_row_stride : tl.constexpr,
-    r,   r_row_stride : tl.constexpr,
+    dY,
+    dY_row_stride: tl.constexpr,
+    dX,
+    dX_row_stride: tl.constexpr,
+    X,
+    X_row_stride: tl.constexpr,
+    W,
+    W_row_stride: tl.constexpr,
+    r,
+    r_row_stride: tl.constexpr,
     # dW, dW_row_stride,
-    n_cols     : tl.constexpr,
-    eps        : tl.constexpr,
-    GEMMA      : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    n_cols: tl.constexpr,
+    eps: tl.constexpr,
+    GEMMA: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     """
-        Fast RMS Layernorm kernel for the backward pass
-        Inspiration from a Triton tutorial:
-        https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    Fast RMS Layernorm kernel for the backward pass
+    Inspiration from a Triton tutorial:
+    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
     """
     row_idx = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
     dY += row_idx * dY_row_stride
-    X  += row_idx *  X_row_stride
-    r  += row_idx *  r_row_stride
+    X += row_idx * X_row_stride
+    r += row_idx * r_row_stride
 
-    if GEMMA: dX += row_idx * dY_row_stride
-    else:     dX = dY
+    if GEMMA:
+        dX += row_idx * dY_row_stride
+    else:
+        dX = dY
 
     dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
-    X_row  = tl.load(X  + col_offsets, mask = mask, other = 0).to(tl.float32)
-    W_row  = tl.load(W  + col_offsets, mask = mask, other = 0).to(tl.float32)
+    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
 
     # Get saved row variance
     inv_var = tl.load(r).to(tl.float32)
     normed = X_row * inv_var
 
-    if GEMMA: dY_W = dY_row * (W_row + 1.0)
-    else:     dY_W = dY_row * W_row
+    if GEMMA:
+        dY_W = dY_row * (W_row + 1.0)
+    else:
+        dY_W = dY_row * W_row
 
     rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
-    output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
+    output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
     tl.store(dX + col_offsets, output, mask = mask)
-pass
+
+
 _rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
 _rms_layernorm_backward = triton.heuristics(
     {
@@ -106,13 +120,17 @@ def _rms_layernorm_backward(
 
 @triton.jit
 def _gemma_rms_layernorm_forward(
-    Y, Y_row_stride : tl.constexpr,
-    X, X_row_stride : tl.constexpr,
-    W, W_row_stride : tl.constexpr,
-    r, r_row_stride : tl.constexpr,
-    n_cols     : tl.constexpr,
-    eps        : tl.constexpr,
-    BLOCK_SIZE : tl.constexpr,
+    Y,
+    Y_row_stride: tl.constexpr,
+    X,
+    X_row_stride: tl.constexpr,
+    W,
+    W_row_stride: tl.constexpr,
+    r,
+    r_row_stride: tl.constexpr,
+    n_cols: tl.constexpr,
+    eps: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
     # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
@@ -135,20 +153,19 @@ def _gemma_rms_layernorm_forward(
     output = normed * (W_row + 1.0)
 
     tl.store(Y + col_offsets, output, mask = mask)
-pass
 
 
 class Fast_RMS_Layernorm(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False):
+    def forward(ctx, X: torch.Tensor, W: torch.Tensor, eps: float, gemma: bool = False):
         shape = X.shape
-        dim : int = shape[-1]
+        dim: int = shape[-1]
         X = X.view(-1, dim)
-        n_rows : int
-        n_cols : int
+        n_rows: int
+        n_cols: int
         n_rows, n_cols = X.shape
-        BLOCK_SIZE : int
-        num_warps  : int
+        BLOCK_SIZE: int
+        num_warps: int
         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
         device = X.device
 
@@ -158,119 +175,139 @@ def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool =
         fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
         with torch_gpu_device(device):
             fx[(n_rows,)](
-                Y, Y.stride(0),
-                X, X.stride(0),
-                W, W.stride(0),
-                r, r.stride(0),
-                n_cols, eps,
+                Y,
+                Y.stride(0),
+                X,
+                X.stride(0),
+                W,
+                W.stride(0),
+                r,
+                r.stride(0),
+                n_cols,
+                eps,
                 BLOCK_SIZE = BLOCK_SIZE,
-                num_warps  = num_warps,
+                num_warps = num_warps,
             )
         ctx.eps = eps
         ctx.BLOCK_SIZE = BLOCK_SIZE
-        ctx.num_warps  = num_warps
+        ctx.num_warps = num_warps
         ctx.GEMMA = gemma
         ctx.save_for_backward(X, W, r)
         return Y.view(*shape)
-    pass
 
     @staticmethod
-    def backward(ctx, dY : torch.Tensor):
+    def backward(ctx, dY: torch.Tensor):
         shape = dY.shape
-        dim : int = shape[-1]
+        dim: int = shape[-1]
         dY = dY.view(-1, dim)
         X, W, r = ctx.saved_tensors
-        n_rows : int
-        n_cols : int
+        n_rows: int
+        n_cols: int
         n_rows, n_cols = dY.shape
         # dW = X
         dX = torch.empty_like(dY) if ctx.GEMMA else dY
 
         with torch_gpu_device(dY.device):
             _rms_layernorm_backward[(n_rows,)](
-                dY, dY.stride(0),
-                dX, dX.stride(0),
-                X,  X .stride(0),
-                W,  W .stride(0),
-                r,  r .stride(0),
+                dY,
+                dY.stride(0),
+                dX,
+                dX.stride(0),
+                X,
+                X.stride(0),
+                W,
+                W.stride(0),
+                r,
+                r.stride(0),
                 # dW, dW.stride(0),
-                n_cols, ctx.eps,
-                GEMMA      = ctx.GEMMA,
+                n_cols,
+                ctx.eps,
+                GEMMA = ctx.GEMMA,
                 BLOCK_SIZE = ctx.BLOCK_SIZE,
-                num_warps  = ctx.num_warps,
+                num_warps = ctx.num_warps,
             )
         dX = dX.view(*shape)
         return dX, None, None, None
-    pass
-pass
 
 
 # [TODO] Unsure why RMS Layernorm is not torch.compiling properly
 @torch.compiler.disable
-def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
-    W : torch.Tensor = layernorm.weight
-    eps : float = layernorm.variance_epsilon if \
-        hasattr(layernorm, "variance_epsilon") \
+def fast_rms_layernorm(layernorm, X: torch.Tensor, gemma: bool = False):
+    W: torch.Tensor = layernorm.weight
+    eps: float = (
+        layernorm.variance_epsilon
+        if hasattr(layernorm, "variance_epsilon")
         else layernorm.eps
+    )
     out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
     return out
-pass
 
 
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+
 class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
     def forward(self, X):
         return fast_rms_layernorm(self, X, gemma = False)
-    pass
-pass
+
 
 try:
     from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
+
     class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
         def forward(self, X):
             return fast_rms_layernorm(self, X, gemma = False)
-        pass
-    pass
+
+
 except:
     pass
-pass
+
 
 def patch_rms_layernorm():
     import transformers.models.llama.modeling_llama
+
     transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
     try:
         import transformers.models.mllama.modeling_mllama
-        transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
+
+        transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = (
+            Unsloth_MllamaTextRMSNorm
+        )
     except:
         pass
     return
-pass
 
 
 def unpatch_rms_layernorm():
     import transformers.models.llama.modeling_llama
+
     transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
     try:
         import transformers.models.mllama.modeling_mllama
+
         transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
     except:
         pass
     return
-pass
 
 
 def test_rms_layernorm(
-    dim = 1024, eps = 1e-5, dtype = torch.float16,
-    bsz = 21, random_state = 3407, seqlen = 3341,
+    dim = 1024,
+    eps = 1e-5,
+    dtype = torch.float16,
+    bsz = 21,
+    random_state = 3407,
+    seqlen = 3341,
 ):
     from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
     layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
     torch.cuda.manual_seed(random_state)
     torch.manual_seed(random_state)
     torch.nn.init.uniform_(layernorm.weight)
     X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
     XX = X.clone()
-    X .requires_grad_(True)
+    X.requires_grad_(True)
     XX.requires_grad_(True)
     Y = layernorm(X)
     YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
@@ -279,8 +316,7 @@ def test_rms_layernorm(
     # from unsloth.kernels import fast_rms_layernorm
     Y = fast_rms_layernorm(layernorm, XX)
     Y.backward(YY)
-    assert(torch.amax(correct_grad - XX.grad).item() <= 0.05)
-pass
+    assert torch.amax(correct_grad - XX.grad).item() <= 0.05
 
 
 def testing_suite_layernorm():
@@ -297,9 +333,3 @@ def testing_suite_layernorm():
                             random_state = random_state,
                             seqlen = seqlen,
                         )
-                    pass
-                pass
-            pass
-        pass
-    pass
-pass
diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py
index a5e585387..8eecec10c 100644
--- a/unsloth/kernels/rope_embedding.py
+++ b/unsloth/kernels/rope_embedding.py
@@ -16,39 +16,55 @@
 import triton.language as tl
 import torch
 from .utils import calculate_settings, torch_gpu_device, torch_device_stream
-ROPE_GROUP_SIZE : int = 4
+
+ROPE_GROUP_SIZE: int = 4
+
 
 def _rope_embedding(
-    Q,     Q_row_stride: tl.constexpr,
-    cos, cos_row_stride: tl.constexpr,
-    sin, sin_row_stride: tl.constexpr,
+    Q,
+    Q_row_stride: tl.constexpr,
+    cos,
+    cos_row_stride: tl.constexpr,
+    sin,
+    sin_row_stride: tl.constexpr,
     seqlen,
-    head_dim      : tl.constexpr,
-    n_heads       : tl.constexpr,
-    BACKWARD_PASS : tl.constexpr,
-    BLOCK_SIZE    : tl.constexpr,
+    head_dim: tl.constexpr,
+    n_heads: tl.constexpr,
+    BACKWARD_PASS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     """
-        Calculates the RoPE Embedding quickly
-        RoPE is Q * cos + rotate_half(Q) * sin
-        See our blog post for more info
+    Calculates the RoPE Embedding quickly
+    RoPE is Q * cos + rotate_half(Q) * sin
+    See our blog post for more info
     """
     ROPE_GROUP_SIZE = 4
-    row_position  = tl.program_id(0)
+    row_position = tl.program_id(0)
     group_head_position = tl.program_id(1)
-    col_offsets  = tl.arange(0, BLOCK_SIZE)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
     half_head_dim = head_dim // 2
     mask = col_offsets < half_head_dim
 
-    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
-                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
-    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
-                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
+    sin1 = tl.load(
+        sin
+        + (row_position % seqlen) * sin_row_stride
+        + half_head_dim * 0
+        + col_offsets,
+        mask = mask,
+        other = 0,
+    )
+    cos1 = tl.load(
+        cos
+        + (row_position % seqlen) * cos_row_stride
+        + half_head_dim * 0
+        + col_offsets,
+        mask = mask,
+        other = 0,
+    )
 
     if BACKWARD_PASS:
         # See our blog post for more info.
         sin1 = -sin1
-    pass
 
     # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
     head_start = group_head_position * ROPE_GROUP_SIZE
@@ -57,16 +73,18 @@ def _rope_embedding(
     # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
     for k in range(head_start, head_end):
         offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
-        offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
+        offs_q2 = (
+            row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
+        )
 
         # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
         Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
         Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
 
-        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
-        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
-    pass
-pass
+        tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask = mask)
+        tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask = mask)
+
+
 _rope_embedding = triton.jit(_rope_embedding)
 _rope_embedding = triton.heuristics(
     {
@@ -79,76 +97,97 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
     @staticmethod
     def forward(ctx, Q, cos, sin):
         cos, sin = cos.squeeze(), sin.squeeze()
-        batch    : int
-        seq_len  : int
-        n_heads  : int
-        head_dim : int
+        batch: int
+        seq_len: int
+        n_heads: int
+        head_dim: int
         batch, seq_len, n_heads, head_dim = Q.shape
-        Q = Q.view(batch*seq_len, n_heads*head_dim)
-        n_rows : int
-        n_cols : int
+        Q = Q.view(batch * seq_len, n_heads * head_dim)
+        n_rows: int
+        n_cols: int
         n_rows, n_cols = Q.shape
-        assert(seq_len <= cos.shape[0])
+        assert seq_len <= cos.shape[0]
 
         # [TODO] Changing blocksize to head_dim//2 seems to have
         # some concurrency / un-deterministic issues.
-        BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
+        BLOCK_SIZE, num_warps = calculate_settings(head_dim // 2)  # (head_dim//2)
 
         # group_size = 4 # 4 or 8, too large group_size can hurt performance.
-        div : int
-        mod : int
+        div: int
+        mod: int
         div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
-        n_groups : int = div + (mod != 0)
+        n_groups: int = div + (mod != 0)
 
         with torch_gpu_device(Q.device):
-            _rope_embedding[(n_rows, n_groups, )](
-                  Q,   Q.stride(0),
-                cos, cos.stride(0),
-                sin, sin.stride(0),
+            _rope_embedding[
+                (
+                    n_rows,
+                    n_groups,
+                )
+            ](
+                Q,
+                Q.stride(0),
+                cos,
+                cos.stride(0),
+                sin,
+                sin.stride(0),
                 seq_len,
-                head_dim, n_heads,
+                head_dim,
+                n_heads,
                 BACKWARD_PASS = False,
                 BLOCK_SIZE = BLOCK_SIZE,
-                num_warps  = num_warps,
+                num_warps = num_warps,
             )
         ctx.BLOCK_SIZE = BLOCK_SIZE
-        ctx.num_warps  = num_warps
+        ctx.num_warps = num_warps
         ctx.n_groups = n_groups
         ctx.cos = cos
         ctx.sin = sin
         return Q.view(batch, seq_len, n_heads, head_dim)
-    pass
 
     @staticmethod
     def backward(ctx, dY):
-        batch    : int
-        seq_len  : int
-        n_heads  : int
-        head_dim : int
+        batch: int
+        seq_len: int
+        n_heads: int
+        head_dim: int
         batch, seq_len, n_heads, head_dim = dY.shape
-        dY = dY.reshape(batch*seq_len, n_heads*head_dim)
+        dY = dY.reshape(batch * seq_len, n_heads * head_dim)
         # Must be reshape not view
-        n_rows : int
-        n_cols : int
+        n_rows: int
+        n_cols: int
         n_rows, n_cols = dY.shape
 
         cos = ctx.cos
         sin = ctx.sin
 
         with torch_gpu_device(dY.device):
-            _rope_embedding[(n_rows, ctx.n_groups, )](
-                dY,  dY .stride(0),
-                cos, cos.stride(0),
-                sin, sin.stride(0),
-                seq_len, head_dim, n_heads,
+            _rope_embedding[
+                (
+                    n_rows,
+                    ctx.n_groups,
+                )
+            ](
+                dY,
+                dY.stride(0),
+                cos,
+                cos.stride(0),
+                sin,
+                sin.stride(0),
+                seq_len,
+                head_dim,
+                n_heads,
                 BACKWARD_PASS = True,
                 BLOCK_SIZE = ctx.BLOCK_SIZE,
-                num_warps  = ctx.num_warps,
+                num_warps = ctx.num_warps,
             )
         dY = dY.view(batch, seq_len, n_heads, head_dim)
-        return dY, None, None,
-    pass
-pass
+        return (
+            dY,
+            None,
+            None,
+        )
+
 
 # [TODO] Unsure why RoPE Embedding is not torch.compiling properly
 @torch.compiler.disable
@@ -158,7 +197,6 @@ def fast_rope_embedding(Q, K, cos, sin):
     # synchronize before cat to avoid race condition
     torch_device_stream(Q.device).synchronize()
     return Q, K
-pass
 
 
 class Slow_RoPE_Embedding(torch.autograd.Function):
@@ -172,7 +210,7 @@ def forward(ctx, Q, cos, sin, position_ids):
             sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 
         # Q * cos + rotate_half(Q) * sin
-        half = Q.shape[-1]//2
+        half = Q.shape[-1] // 2
         RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
         Q *= cos
         Q.addcmul_(RH_Q, sin)
@@ -180,21 +218,18 @@ def forward(ctx, Q, cos, sin, position_ids):
         # Q += RH_Q
         ctx.save_for_backward(cos, sin)
         return Q
-    pass
 
     @staticmethod
     def backward(ctx, dY):
         cos, sin = ctx.saved_tensors
         # Q * cos + rotate_half.T(Q) * sin
-        half = dY.shape[-1]//2
+        half = dY.shape[-1] // 2
         RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
         dY *= cos
         dY.addcmul_(RH_dY, sin)
         # RH_dY *= sin
         # dY += RH_dY
         return dY, None, None, None
-    pass
-pass
 
 
 def inplace_rope_embedding(Q, K, cos, sin, position_ids):
@@ -202,4 +237,3 @@ def inplace_rope_embedding(Q, K, cos, sin, position_ids):
     K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
     torch_device_stream(Q.device).synchronize()
     return Q, K
-pass
diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py
index c1d5e3128..81e62660b 100644
--- a/unsloth/kernels/swiglu.py
+++ b/unsloth/kernels/swiglu.py
@@ -19,38 +19,54 @@
 
 
 @triton.jit
-def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
+def _fg_kernel(
+    e,
+    g,
+    h,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
     block_idx = tl.program_id(0)
-    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
 
     e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)
 
     # f = e * sigmoid(e)
-    f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
-    f_row = f_row.to(g_row.dtype) # Exact copy from HF
+    f_row = e_row * tl.sigmoid(e_row)  # e_row / (1 + tl.exp(-e_row))
+    f_row = f_row.to(g_row.dtype)  # Exact copy from HF
     # h = f * g
     h_row = f_row * g_row
 
     # Store h
     tl.store(h + offsets, h_row, mask = mask)
-pass
 
 
 def swiglu_fg_kernel(e, g):
     batch, seq_len, hd = e.shape
     n_elements = e.numel()
     h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
     with torch_gpu_device(e.device):
-        _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
+        _fg_kernel[grid](
+            e,
+            g,
+            h,
+            n_elements,
+            BLOCK_SIZE = 1024,
+        )
     return h
-pass
 
 
 @triton.jit
-def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
+def _DWf_DW_dfg_kernel(
+    DW,
+    e,
+    g,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
     """
     e = e.float()
     se = 1.0 / (1.0 + torch.exp(-e))
@@ -61,21 +77,21 @@ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
     de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
     """
     block_idx = tl.program_id(0)
-    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
 
-    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
-    e_row  = tl.load(e  + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row  = tl.load(g  + offsets, mask = mask, other = 0)#.to(tl.float32)
+    DW_row = tl.load(DW + offsets, mask = mask, other = 0)  # .to(tl.float32)
+    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)  # .to(tl.float32)
 
     # e = e.float()
     # se = 1.0 / (1.0 + torch.exp(-e))
-    se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
+    se_row = tl.sigmoid(e_row)  # 1.0 / (1.0 + tl.exp(-e_row))
     # f = (se * e).to(dtype)
     f_row = se_row * e_row
     f_row = f_row.to(DW_row.dtype)
     # h = f * g
-    h_row  =  f_row * g_row
+    h_row = f_row * g_row
     # df = DW * f
     df_row = DW_row * f_row
     # dg = DW * g
@@ -85,17 +101,21 @@ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
     de_row = de_row.to(DW_row.dtype)
 
     # Store derivatives in buffers
-    tl.store(DW + offsets, h_row,  mask = mask) # h  = f * g
-    tl.store(e  + offsets, df_row, mask = mask) # df = DW * f
-    tl.store(g  + offsets, de_row, mask = mask) # de
-pass
+    tl.store(DW + offsets, h_row, mask = mask)  # h  = f * g
+    tl.store(e + offsets, df_row, mask = mask)  # df = DW * f
+    tl.store(g + offsets, de_row, mask = mask)  # de
 
 
 def swiglu_DWf_DW_dfg_kernel(DW, e, g):
     batch_seq_len, hd = e.shape
     n_elements = e.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
     with torch_gpu_device(e.device):
-        _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+        _DWf_DW_dfg_kernel[grid](
+            DW,
+            e,
+            g,
+            n_elements,
+            BLOCK_SIZE = 1024,
+        )
     return DW, e, g
-pass
diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py
index 820ac1678..1e3be4d5a 100644
--- a/unsloth/kernels/utils.py
+++ b/unsloth/kernels/utils.py
@@ -14,7 +14,8 @@
 
 import triton
 import ctypes
-MAX_FUSED_SIZE : int = 65536
+
+MAX_FUSED_SIZE: int = 65536
 next_power_of_2 = triton.next_power_of_2
 import functools
 from typing import Optional
@@ -32,11 +33,14 @@
 
 # torch.cuda.amp.custom_fwd is deprecated >= 2.4
 import torch
+
 torch_Tensor = torch.Tensor
 from packaging.version import Version
 
 if DEVICE_TYPE == "xpu" and Version(torch.__version__) < Version("2.6.0"):
-    raise RuntimeError("Intel xpu currently supports unsloth with torch.version >= 2.6.0")
+    raise RuntimeError(
+        "Intel xpu currently supports unsloth with torch.version >= 2.6.0"
+    )
 
 if Version(torch.__version__) < Version("2.4.0"):
     torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
@@ -44,7 +48,6 @@
 else:
     torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
     torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
-pass
 
 if DEVICE_TYPE == "xpu":
     torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
@@ -55,42 +58,58 @@
 from packaging.version import Version
 import triton
 import triton.language as tl
+
 if Version(triton.__version__) >= Version("3.0.0"):
     if DEVICE_TYPE == "xpu":
         triton_tanh = tl.extra.intel.libdevice.tanh
     else:
         from triton.language.extra import libdevice
+
         triton_tanh = libdevice.tanh
     triton_cast = tl.cast
 else:
     triton_tanh = tl.math.tanh
+
     # No casting in old Triton versions
     @triton.jit
     def triton_cast(x, dtype):
         return x.to(dtype)
-    pass
-pass
 
 
 @functools.lru_cache(1)
 def is_cdna():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
+    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
+        "gfx940",
+        "gfx941",
+        "gfx942",
+    )
 
 
-def calculate_settings(n : int) -> (int, int,):
-    BLOCK_SIZE : int = next_power_of_2(n)
+def calculate_settings(
+    n: int,
+) -> (
+    int,
+    int,
+):
+    BLOCK_SIZE: int = next_power_of_2(n)
     if BLOCK_SIZE > MAX_FUSED_SIZE:
-        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
-                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
-    num_warps : int = 4
-    if   BLOCK_SIZE >= 32768: num_warps = 32
-    elif BLOCK_SIZE >=  8192: num_warps = 16
-    elif BLOCK_SIZE >=  2048: num_warps = 8
+        raise RuntimeError(
+            f"Cannot launch Triton kernel since n = {n} exceeds "
+            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
+        )
+    num_warps: int = 4
+    if BLOCK_SIZE >= 32768:
+        num_warps = 32
+    elif BLOCK_SIZE >= 8192:
+        num_warps = 16
+    elif BLOCK_SIZE >= 2048:
+        num_warps = 8
     return BLOCK_SIZE, num_warps
-pass
+
 
 HAS_CUDA_STREAM = False
 import bitsandbytes as bnb
+
 # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
 HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
 get_ptr = bnb.functional.get_ptr
@@ -105,8 +124,10 @@ def calculate_settings(n : int) -> (int, int,):
         torch_gpu_device = torch.xpu.device
 else:
     from contextlib import nullcontext
-    def torch_gpu_device(device): return nullcontext()
-pass
+
+    def torch_gpu_device(device):
+        return nullcontext()
+
 
 # INTEL GPU Specific Logic
 if DEVICE_TYPE == "xpu":
@@ -116,9 +137,10 @@ def torch_gpu_device(device): return nullcontext()
     _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
 
 c_void_p = ctypes.c_void_p
+
+
 def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
     return c_void_p(_gpu_getCurrentRawStream(tensor.device.index))
-pass
 
 
 # Get array of CUDA streams and other buffers
@@ -130,10 +152,12 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
 # INTEL GPU Specific Logic
 if DEVICE_TYPE == "xpu":
     _XPU_STREAMS = {
-        (index := torch.xpu.device(i).idx) : ctypes.c_void_p(torch._C._xpu_getCurrentRawStream(index))
+        (index := torch.xpu.device(i).idx): ctypes.c_void_p(
+            torch._C._xpu_getCurrentRawStream(index)
+        )
         for i in range(DEVICE_COUNT)
     }
-    XPU_STREAMS    = [None] * (max(_XPU_STREAMS.keys()) + 1)
+    XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)
     WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
     ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
     for k, v in _XPU_STREAMS.items():
@@ -143,23 +167,25 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
 else:
     # NVIDIA GPU Default Logic
     _CUDA_STREAMS = {
-        (index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
+        (index := torch.cuda.device(i).idx): ctypes.c_void_p(
+            torch._C._cuda_getCurrentRawStream(index)
+        )
         for i in range(DEVICE_COUNT)
     }
-    CUDA_STREAMS   = [None] * (max(_CUDA_STREAMS.keys()) + 1)
+    CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
     WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
     ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
-    for k, v in _CUDA_STREAMS.items(): CUDA_STREAMS[k] = v
+    for k, v in _CUDA_STREAMS.items():
+        CUDA_STREAMS[k] = v
     CUDA_STREAMS = tuple(CUDA_STREAMS)
     del _CUDA_STREAMS
-pass
 
 # Bitsandbytes operations
-ctypes_c_int   = ctypes.c_int
+ctypes_c_int = ctypes.c_int
 ctypes_c_int32 = ctypes.c_int32
-cdequantize_blockwise_fp32      = bnb.functional.lib.cdequantize_blockwise_fp32
-cdequantize_blockwise_fp16_nf4  = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
-cdequantize_blockwise_bf16_nf4  = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
+cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
+cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
+cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
 
 if DEVICE_TYPE == "xpu":
     # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115
@@ -171,18 +197,23 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
     cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
 
 
-torch_device_stream = torch.xpu.current_stream if DEVICE_TYPE == "xpu" else torch.cuda.current_stream
+torch_device_stream = (
+    torch.xpu.current_stream if DEVICE_TYPE == "xpu" else torch.cuda.current_stream
+)
 
 torch_mm = torch.mm
 torch_mv = torch.mv
 torch_matmul = torch.matmul
-torch_addmm  = torch.addmm
-torch_empty  = torch.empty
-torch_float32  = torch.float32
-torch_float16  = torch.float16
+torch_addmm = torch.addmm
+torch_empty = torch.empty
+torch_float32 = torch.float32
+torch_float16 = torch.float16
 torch_bfloat16 = torch.bfloat16
 
-def QUANT_STATE(W): return getattr(W, "quant_state", None)
+
+def QUANT_STATE(W):
+    return getattr(W, "quant_state", None)
+
 
 def get_lora_parameters(proj):
     """
@@ -190,7 +221,9 @@ def get_lora_parameters(proj):
     If QAT is enabled, additionally fake quantize the base layer and lora weights.
     """
     # For DPO or disabled adapters
-    base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
+    base_layer = getattr(
+        proj, "base_layer", proj
+    )  # (proj.base_layer if hasattr(proj, "base_layer") else proj)
     W = base_layer.weight
 
     # Optionally apply fake quantization to base layer weights for QAT
@@ -214,10 +247,10 @@ def get_lora_parameters(proj):
     # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
     if getattr(proj, "disable_adapters", True) or proj.merged:
         return W, W_quant, None, None, None
-    pass
 
     adapter = getattr(proj, "active_adapters", None)
-    if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
+    if adapter is None:
+        adapter = getattr(proj, "active_adapter", ("default"))
     adapter = adapter[0]
 
     # Optionally apply fake quantization to lora weights for QAT
@@ -241,12 +274,13 @@ def get_lora_parameters(proj):
         B,
         proj.scaling[adapter],
     )
-pass
 
 
 def get_lora_parameters_bias(proj):
     # For DPO or disabled adapters
-    base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
+    base_layer = getattr(
+        proj, "base_layer", proj
+    )  # (proj.base_layer if hasattr(proj, "base_layer") else proj)
     W = base_layer.weight
 
     # Get quant state for 4bit or FP8
@@ -259,7 +293,6 @@ def get_lora_parameters_bias(proj):
     # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
     if getattr(proj, "disable_adapters", True) or proj.merged:
         return W, W_quant, None, None, None, base_layer.bias
-    pass
 
     if getattr(base_layer, "quant_method", None) == "fp8":
         # we need to somehow store and pass this information :)
@@ -267,21 +300,23 @@ def get_lora_parameters_bias(proj):
         W_quant.block_size = W.block_size
 
     adapter = getattr(proj, "active_adapters", None)
-    if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
+    if adapter is None:
+        adapter = getattr(proj, "active_adapter", ("default"))
     adapter = adapter[0]
 
     return (
         W,
         W_quant,
-        proj.lora_A [adapter].weight,
-        proj.lora_B [adapter].weight,
+        proj.lora_A[adapter].weight,
+        proj.lora_B[adapter].weight,
         proj.scaling[adapter],
         base_layer.bias,
     )
-pass
 
 
-def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> torch.Tensor:
+def _maybe_fake_quantize_activations(
+    X: torch.Tensor, proj: torch.nn.Module
+) -> torch.Tensor:
     """
     If QAT is enabled, fake quantize the input activations.
     Otherwise, just return the input activations as is.
@@ -292,34 +327,35 @@ def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) ->
     if activation_fake_quantizer is not None:
         X = activation_fake_quantizer(X)
     return X
-pass
 
 
 # INTEL GPU Specific Logic
 if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
+
     @torch.inference_mode
     def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
         # TODO: After adding XPU BNB support, check this function
-        if quant_state is None: return W
-        if W.dtype == torch.float8_e4m3fn: return weight_dequant(W, quant_state)
+        if quant_state is None:
+            return W
+        if W.dtype == torch.float8_e4m3fn:
+            return weight_dequant(W, quant_state)
         if type(quant_state) is not list:
             # New quant_state as a class
             # https://github.com/TimDettmers/bitsandbytes/pull/763/files
-            absmax     = quant_state.absmax
-            shape      = quant_state.shape
-            dtype      = quant_state.dtype
-            blocksize  = quant_state.blocksize
-            offset     = quant_state.offset
-            state2     = quant_state.state2
-            absmax2    = state2.absmax
-            code2      = state2.code
+            absmax = quant_state.absmax
+            shape = quant_state.shape
+            dtype = quant_state.dtype
+            blocksize = quant_state.blocksize
+            offset = quant_state.offset
+            state2 = quant_state.state2
+            absmax2 = state2.absmax
+            code2 = state2.code
             blocksize2 = state2.blocksize
         else:
             # Old quant_state as a list of lists
             absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
             offset, state2 = compressed_stats
             absmax2, code2, blocksize2, _, _, _, _ = state2
-        pass
         global XPU_STREAMS
         device = W.device
         device_index = device.index
@@ -328,74 +364,104 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
         n_elements_absmax = absmax.numel()
         # Create weight matrix
         if use_global_buffer:
-
             # Use same buffers for faster inference
-            size = shape[0]*shape[1]
+            size = shape[0] * shape[1]
             global WEIGHT_BUFFERS
             global ABSMAX_BUFFERS
             WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
             ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
             if WEIGHT_BUFFER is None:
-                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False)
-                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
-
-            if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
-            if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
+                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
+                    size, dtype = dtype, device = device, requires_grad = False
+                )
+                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
+                    n_elements_absmax,
+                    dtype = torch.float32,
+                    device = device,
+                    requires_grad = False,
+                )
+
+            if size > WEIGHT_BUFFER.numel():
+                WEIGHT_BUFFER.resize_(size)
+            if n_elements_absmax > ABSMAX_BUFFER.numel():
+                ABSMAX_BUFFER.resize_(n_elements_absmax)
 
             out = WEIGHT_BUFFER[:size].view(shape)
             out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
         else:
             if out is None:
-                out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
+                out = torch_empty(
+                    shape, dtype = dtype, device = device, requires_grad = False
+                )
             else:
-                assert(out.shape == shape)
-                assert(out.dtype == dtype)
-            out_absmax = torch_empty(n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False)
-        pass
+                assert out.shape == shape
+                assert out.dtype == dtype
+            out_absmax = torch_empty(
+                n_elements_absmax,
+                dtype = torch_float32,
+                device = device,
+                requires_grad = False,
+            )
 
         # NF4 dequantization of statistics
         ptr_out_absmax = get_ptr(out_absmax)
         with torch_gpu_device(device):
             cdequantize_blockwise_fp32(
-                get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
-                ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), XPU_STREAM
+                get_ptr(code2),
+                get_ptr(absmax),
+                get_ptr(absmax2),
+                ptr_out_absmax,
+                ctypes_c_int(blocksize2),
+                ctypes_c_int(n_elements_absmax),
+                XPU_STREAM,
             )
             out_absmax += offset
 
             # Dequantize W
-            fx = cdequantize_blockwise_fp16_nf4 if dtype == torch_float16 else \
-                 cdequantize_blockwise_bf16_nf4
-            fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
-               ctypes_c_int(blocksize), ctypes_c_int(out.numel()), XPU_STREAM,)
-        pass
+            fx = (
+                cdequantize_blockwise_fp16_nf4
+                if dtype == torch_float16
+                else cdequantize_blockwise_bf16_nf4
+            )
+            fx(
+                get_ptr(None),
+                get_ptr(W),
+                ptr_out_absmax,
+                get_ptr(out),
+                ctypes_c_int(blocksize),
+                ctypes_c_int(out.numel()),
+                XPU_STREAM,
+            )
         # Careful returning transposed data
-        is_transposed = (True if W.shape[0] == 1 else False)
+        is_transposed = True if W.shape[0] == 1 else False
         return out.t() if is_transposed else out
-    pass
+
 # NVIDIA GPU Default Logic
 elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
+
     @torch.inference_mode
     def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
-        if quant_state is None: return W
-        if W.dtype == torch.float8_e4m3fn: return weight_dequant(W, quant_state)
+        if quant_state is None:
+            return W
+        if W.dtype == torch.float8_e4m3fn:
+            return weight_dequant(W, quant_state)
         if type(quant_state) is not list:
             # New quant_state as a class
             # https://github.com/TimDettmers/bitsandbytes/pull/763/files
-            absmax     = quant_state.absmax
-            shape      = quant_state.shape
-            dtype      = quant_state.dtype
-            blocksize  = quant_state.blocksize
-            offset     = quant_state.offset
-            state2     = quant_state.state2
-            absmax2    = state2.absmax
-            code2      = state2.code
+            absmax = quant_state.absmax
+            shape = quant_state.shape
+            dtype = quant_state.dtype
+            blocksize = quant_state.blocksize
+            offset = quant_state.offset
+            state2 = quant_state.state2
+            absmax2 = state2.absmax
+            code2 = state2.code
             blocksize2 = state2.blocksize
         else:
             # Old quant_state as a list of lists
             absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
             offset, state2 = compressed_stats
             absmax2, code2, blocksize2, _, _, _, _ = state2
-        pass
         global CUDA_STREAMS
         device = W.device
         device_index = device.index
@@ -405,73 +471,103 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
 
         # Create weight matrix
         if use_global_buffer:
-
             # Use same buffers for faster inference
-            size = shape[0]*shape[1]
+            size = shape[0] * shape[1]
             global WEIGHT_BUFFERS
             global ABSMAX_BUFFERS
             WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
             ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
             if WEIGHT_BUFFER is None:
-                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False)
-                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False)
-
-            if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
-            if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
+                WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
+                    size, dtype = dtype, device = device, requires_grad = False
+                )
+                ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
+                    n_elements_absmax,
+                    dtype = torch_float32,
+                    device = device,
+                    requires_grad = False,
+                )
+
+            if size > WEIGHT_BUFFER.numel():
+                WEIGHT_BUFFER.resize_(size)
+            if n_elements_absmax > ABSMAX_BUFFER.numel():
+                ABSMAX_BUFFER.resize_(n_elements_absmax)
 
             out = WEIGHT_BUFFER[:size].view(shape)
             out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
         else:
             if out is None:
-                out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
+                out = torch_empty(
+                    shape, dtype = dtype, device = device, requires_grad = False
+                )
             else:
-                assert(out.shape == shape)
-                assert(out.dtype == dtype)
-            out_absmax = torch_empty(n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False)
-        pass
+                assert out.shape == shape
+                assert out.dtype == dtype
+            out_absmax = torch_empty(
+                n_elements_absmax,
+                dtype = torch_float32,
+                device = device,
+                requires_grad = False,
+            )
 
         # NF4 dequantization of statistics
         ptr_out_absmax = get_ptr(out_absmax)
         with torch_gpu_device(device):
             cdequantize_blockwise_fp32(
-                get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
-                ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM
+                get_ptr(code2),
+                get_ptr(absmax),
+                get_ptr(absmax2),
+                ptr_out_absmax,
+                ctypes_c_int(blocksize2),
+                ctypes_c_int(n_elements_absmax),
+                CUDA_STREAM,
             )
             out_absmax += offset
 
             # Dequantize W
-            fx = cdequantize_blockwise_fp16_nf4 if dtype == torch_float16 else \
-                 cdequantize_blockwise_bf16_nf4
-            fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
-               ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,)
-        pass
+            fx = (
+                cdequantize_blockwise_fp16_nf4
+                if dtype == torch_float16
+                else cdequantize_blockwise_bf16_nf4
+            )
+            fx(
+                get_ptr(None),
+                get_ptr(W),
+                ptr_out_absmax,
+                get_ptr(out),
+                ctypes_c_int(blocksize),
+                ctypes_c_int(out.numel()),
+                CUDA_STREAM,
+            )
         # Careful returning transposed data
-        is_transposed = (True if W.shape[0] == 1 else False)
+        is_transposed = True if W.shape[0] == 1 else False
         return out.t() if is_transposed else out
-    pass
+
 else:
+
     @torch.inference_mode
     def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
-        if quant_state is None: return W
-        if W.dtype == torch.float8_e4m3fn: return weight_dequant(W, quant_state)
+        if quant_state is None:
+            return W
+        if W.dtype == torch.float8_e4m3fn:
+            return weight_dequant(W, quant_state)
         if type(quant_state) is not list:
             # New quant_state as a class
             # https://github.com/TimDettmers/bitsandbytes/pull/763/files
-            absmax     = quant_state.absmax
-            shape      = quant_state.shape
-            dtype      = quant_state.dtype
-            blocksize  = quant_state.blocksize
-            offset     = quant_state.offset
-            state2     = quant_state.state2
-            absmax2    = state2.absmax
-            code2      = state2.code
+            absmax = quant_state.absmax
+            shape = quant_state.shape
+            dtype = quant_state.dtype
+            blocksize = quant_state.blocksize
+            offset = quant_state.offset
+            state2 = quant_state.state2
+            absmax2 = state2.absmax
+            code2 = state2.code
             blocksize2 = state2.blocksize
         else:
             # Old quant_state as a list of lists
             absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
             offset, state2 = compressed_stats
             absmax2, code2, blocksize2, _, _, _, _ = state2
-        pass
 
         n_elements_absmax = absmax.numel()
         device = W.device
@@ -480,34 +576,49 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
         if out is None:
             out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
         else:
-            assert(out.shape == shape)
-            assert(out.dtype == dtype)
-        out_absmax = torch_empty(n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False)
+            assert out.shape == shape
+            assert out.dtype == dtype
+        out_absmax = torch_empty(
+            n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False
+        )
 
         # Do dequantization
         ptr_out_absmax = get_ptr(out_absmax)
         cdequantize_blockwise_fp32(
-            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
-            ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax),
+            get_ptr(code2),
+            get_ptr(absmax),
+            get_ptr(absmax2),
+            ptr_out_absmax,
+            ctypes_c_int(blocksize2),
+            ctypes_c_int(n_elements_absmax),
         )
         out_absmax += offset
 
-        fx = cdequantize_blockwise_fp16_nf4 if dtype == torch_float16 else \
-             cdequantize_blockwise_bf16_nf4
-        fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
-           ctypes_c_int(blocksize), ctypes_c_int(out.numel()),)
+        fx = (
+            cdequantize_blockwise_fp16_nf4
+            if dtype == torch_float16
+            else cdequantize_blockwise_bf16_nf4
+        )
+        fx(
+            get_ptr(None),
+            get_ptr(W),
+            ptr_out_absmax,
+            get_ptr(out),
+            ctypes_c_int(blocksize),
+            ctypes_c_int(out.numel()),
+        )
 
         # Careful returning transposed data
-        is_transposed = (True if W.shape[0] == 1 else False)
+        is_transposed = True if W.shape[0] == 1 else False
         return out.t() if is_transposed else out
-    pass
-pass
 
 
 # INTEL GPU Specific Logic
-if  DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
+if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
+
     def fast_gemv(X, W, quant_state, out = None):
-        if quant_state is None: return torch_matmul(X, W, out = out)
+        if quant_state is None:
+            return torch_matmul(X, W, out = out)
         # For fast X @ W where seq_len == 1
         # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
         _, q_len, hd = X.shape
@@ -515,21 +626,22 @@ def fast_gemv(X, W, quant_state, out = None):
 
         if type(quant_state) is not list:
             # https://github.com/TimDettmers/bitsandbytes/pull/763/files
-            absmax     = quant_state.absmax
-            shape      = quant_state.shape
-            dtype      = quant_state.dtype
-            blocksize  = quant_state.blocksize
-            stats      = quant_state.code
-            offset     = quant_state.offset
-            state2     = quant_state.state2
-            absmax2    = state2.absmax
-            code2      = state2.code
+            absmax = quant_state.absmax
+            shape = quant_state.shape
+            dtype = quant_state.dtype
+            blocksize = quant_state.blocksize
+            stats = quant_state.code
+            offset = quant_state.offset
+            state2 = quant_state.state2
+            absmax2 = state2.absmax
+            code2 = state2.code
             blocksize2 = state2.blocksize
         else:
-            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
+            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
+                quant_state
+            )
             offset, state2 = compressed_stats
             absmax2, code2, blocksize2, _, _, _, _ = state2
-        pass
         global XPU_STREAMS
         device = W.device
         device_index = device.index
@@ -539,7 +651,15 @@ def fast_gemv(X, W, quant_state, out = None):
         bout = shape[0]
 
         if out is None:
-            out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
+            out = torch_empty(
+                (
+                    1,
+                    1,
+                    bout,
+                ),
+                dtype = dtype,
+                device = device,
+            )
         # else:
         #     assert(out.shape == (1, 1, bout,))
         # pass
@@ -553,7 +673,7 @@ def fast_gemv(X, W, quant_state, out = None):
         k = shape[1]
         lda = shape[0]
         ldc = shape[0]
-        ldb = (hd+1)//2
+        ldb = (hd + 1) // 2
         m = ctypes_c_int32(m)
         n = ctypes_c_int32(n)
         k = ctypes_c_int32(k)
@@ -564,25 +684,47 @@ def fast_gemv(X, W, quant_state, out = None):
         df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
         with torch_gpu_device(device):
             cdequantize_blockwise_fp32(
-                get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
-                ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), XPU_STREAM,
+                get_ptr(code2),
+                get_ptr(absmax),
+                get_ptr(absmax2),
+                get_ptr(df),
+                ctypes_c_int(blocksize2),
+                ctypes_c_int(df.numel()),
+                XPU_STREAM,
             )
             df += offset
             absmax = df
 
-            fx = cgemm_4bit_inference_naive_fp16 if dtype == torch_float16 else \
-                cgemm_4bit_inference_naive_bf16
+            fx = (
+                cgemm_4bit_inference_naive_fp16
+                if dtype == torch_float16
+                else cgemm_4bit_inference_naive_bf16
+            )
 
             blocksize = ctypes_c_int32(blocksize)
-            fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
-               lda, ldb, ldc, blocksize, XPU_STREAM,)
-        pass
+            fx(
+                m,
+                n,
+                k,
+                get_ptr(X),
+                get_ptr(W),
+                get_ptr(absmax),
+                get_ptr(stats),
+                get_ptr(out),
+                lda,
+                ldb,
+                ldc,
+                blocksize,
+                XPU_STREAM,
+            )
 
         return out
-    pass
+
 elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
+
     def fast_gemv(X, W, quant_state, out = None):
-        if quant_state is None: return torch_matmul(X, W, out = out)
+        if quant_state is None:
+            return torch_matmul(X, W, out = out)
         # For fast X @ W where seq_len == 1
         # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
         _, q_len, hd = X.shape
@@ -590,21 +732,22 @@ def fast_gemv(X, W, quant_state, out = None):
 
         if type(quant_state) is not list:
             # https://github.com/TimDettmers/bitsandbytes/pull/763/files
-            absmax     = quant_state.absmax
-            shape      = quant_state.shape
-            dtype      = quant_state.dtype
-            blocksize  = quant_state.blocksize
-            stats      = quant_state.code
-            offset     = quant_state.offset
-            state2     = quant_state.state2
-            absmax2    = state2.absmax
-            code2      = state2.code
+            absmax = quant_state.absmax
+            shape = quant_state.shape
+            dtype = quant_state.dtype
+            blocksize = quant_state.blocksize
+            stats = quant_state.code
+            offset = quant_state.offset
+            state2 = quant_state.state2
+            absmax2 = state2.absmax
+            code2 = state2.code
             blocksize2 = state2.blocksize
         else:
-            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
+            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
+                quant_state
+            )
             offset, state2 = compressed_stats
             absmax2, code2, blocksize2, _, _, _, _ = state2
-        pass
         global CUDA_STREAMS
         device = W.device
         device_index = device.index
@@ -614,7 +757,15 @@ def fast_gemv(X, W, quant_state, out = None):
         bout = shape[0]
 
         if out is None:
-            out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
+            out = torch_empty(
+                (
+                    1,
+                    1,
+                    bout,
+                ),
+                dtype = dtype,
+                device = device,
+            )
         # else:
         #     assert(out.shape == (1, 1, bout,))
         # pass
@@ -624,7 +775,7 @@ def fast_gemv(X, W, quant_state, out = None):
         k = shape[1]
         lda = shape[0]
         ldc = shape[0]
-        ldb = (hd+1)//2
+        ldb = (hd + 1) // 2
         m = ctypes_c_int32(m)
         n = ctypes_c_int32(n)
         k = ctypes_c_int32(k)
@@ -635,25 +786,47 @@ def fast_gemv(X, W, quant_state, out = None):
         df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
         with torch_gpu_device(device):
             cdequantize_blockwise_fp32(
-                get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
-                ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM,
+                get_ptr(code2),
+                get_ptr(absmax),
+                get_ptr(absmax2),
+                get_ptr(df),
+                ctypes_c_int(blocksize2),
+                ctypes_c_int(df.numel()),
+                CUDA_STREAM,
             )
             df += offset
             absmax = df
 
-            fx = cgemm_4bit_inference_naive_fp16 if dtype == torch_float16 else \
-                 cgemm_4bit_inference_naive_bf16
+            fx = (
+                cgemm_4bit_inference_naive_fp16
+                if dtype == torch_float16
+                else cgemm_4bit_inference_naive_bf16
+            )
 
             blocksize = ctypes_c_int32(blocksize)
-            fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
-               lda, ldb, ldc, blocksize, CUDA_STREAM,)
-        pass
+            fx(
+                m,
+                n,
+                k,
+                get_ptr(X),
+                get_ptr(W),
+                get_ptr(absmax),
+                get_ptr(stats),
+                get_ptr(out),
+                lda,
+                ldb,
+                ldc,
+                blocksize,
+                CUDA_STREAM,
+            )
 
         return out
-    pass
+
 else:
+
     def fast_gemv(X, W, quant_state, out = None):
-        if quant_state is None: return torch_matmul(X, W, out = out)
+        if quant_state is None:
+            return torch_matmul(X, W, out = out)
         # For fast X @ W where seq_len == 1
         # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
         _, q_len, hd = X.shape
@@ -661,27 +834,36 @@ def fast_gemv(X, W, quant_state, out = None):
 
         if type(quant_state) is not list:
             # https://github.com/TimDettmers/bitsandbytes/pull/763/files
-            absmax     = quant_state.absmax
-            shape      = quant_state.shape
-            dtype      = quant_state.dtype
-            blocksize  = quant_state.blocksize
-            stats      = quant_state.code
-            offset     = quant_state.offset
-            state2     = quant_state.state2
-            absmax2    = state2.absmax
-            code2      = state2.code
+            absmax = quant_state.absmax
+            shape = quant_state.shape
+            dtype = quant_state.dtype
+            blocksize = quant_state.blocksize
+            stats = quant_state.code
+            offset = quant_state.offset
+            state2 = quant_state.state2
+            absmax2 = state2.absmax
+            code2 = state2.code
             blocksize2 = state2.blocksize
         else:
-            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
+            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
+                quant_state
+            )
             offset, state2 = compressed_stats
             absmax2, code2, blocksize2, _, _, _, _ = state2
-        pass
         # assert(dtype == X.dtype)
         bout = shape[0]
         device = W.device
 
         if out is None:
-            out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
+            out = torch_empty(
+                (
+                    1,
+                    1,
+                    bout,
+                ),
+                dtype = dtype,
+                device = device,
+            )
         # else:
         #     assert(out.shape == (1, 1, bout,))
         # pass
@@ -691,7 +873,7 @@ def fast_gemv(X, W, quant_state, out = None):
         k = shape[1]
         lda = shape[0]
         ldc = shape[0]
-        ldb = (hd+1)//2
+        ldb = (hd + 1) // 2
         m = ctypes_c_int32(m)
         n = ctypes_c_int32(n)
         k = ctypes_c_int32(k)
@@ -701,29 +883,46 @@ def fast_gemv(X, W, quant_state, out = None):
 
         df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
         cdequantize_blockwise_fp32(
-            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
-            ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
+            get_ptr(code2),
+            get_ptr(absmax),
+            get_ptr(absmax2),
+            get_ptr(df),
+            ctypes_c_int(blocksize2),
+            ctypes_c_int(df.numel()),
         )
         df += offset
         absmax = df
 
-        fx = cgemm_4bit_inference_naive_fp16 if dtype == torch_float16 else \
-             cgemm_4bit_inference_naive_bf16
+        fx = (
+            cgemm_4bit_inference_naive_fp16
+            if dtype == torch_float16
+            else cgemm_4bit_inference_naive_bf16
+        )
 
         blocksize = ctypes_c_int32(blocksize)
-        fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
-           lda, ldb, ldc, blocksize,)
+        fx(
+            m,
+            n,
+            k,
+            get_ptr(X),
+            get_ptr(W),
+            get_ptr(absmax),
+            get_ptr(stats),
+            get_ptr(out),
+            lda,
+            ldb,
+            ldc,
+            blocksize,
+        )
 
         return out
-    pass
-pass
 
 
 def fast_linear_forward(proj, X, temp_lora = None, out = None):
-
     W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
     bsz, q_len, in_dim = X.shape
-    if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
+    if q_len != 1:
+        return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
 
     if W_quant is None:
         out = torch_matmul(X, W.t(), out = out)
@@ -734,7 +933,6 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
     else:
         W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
         out = torch_matmul(X, W, out = out)
-    pass
 
     # Add in LoRA weights
     if lora_A is not None:
@@ -744,7 +942,6 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
         if not hasattr(lora_A, "_fast_lora"):
             lora_A._fast_lora = lora_A.to(dtype)
             lora_B._fast_lora = lora_B.to(dtype)
-        pass
 
         if bsz == 1:
             out = out.view(out_dim)
@@ -752,16 +949,16 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
             out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
         else:
             out = out.view(bsz, out_dim)
-            temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
+            temp_lora = torch_mm(
+                X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora
+            )
             out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
-        pass
         out = out.view(bsz, 1, out_dim)
-    pass
 
-    if bias is not None: out += bias
+    if bias is not None:
+        out += bias
 
     return out
-pass
 
 
 def matmul_lora(X, W, W_quant, A, B, s, out = None):
@@ -773,14 +970,14 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
         reshape = True
     else:
         reshape = False
-    pass
 
     if W.dtype == torch.float8_e4m3fn:
         out = fp8_linear(X, W, W_quant)
     else:
         W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
         out = torch_matmul(X, W, out = out)
-    if W_quant is not None: del W
+    if W_quant is not None:
+        del W
 
     if A is not None:
         # LoRA is enabled
@@ -788,7 +985,5 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
         XA = torch_matmul(X, A.to(dtype))
         out.addmm_(XA, B.to(dtype), alpha = s)
         # out += (X @ A.to(dtype)) @ (s * B.to(dtype))
-    pass
 
     return out.view(batch, seq_len, -1) if reshape else out
-pass
diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py
index 89d3fd763..e8fc4549d 100644
--- a/unsloth/models/__init__.py
+++ b/unsloth/models/__init__.py
@@ -12,18 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .llama     import FastLlamaModel
-from .loader    import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
-from .mistral   import FastMistralModel
-from .qwen2     import FastQwen2Model
-from .qwen3     import FastQwen3Model
+from .llama import FastLlamaModel
+from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
+from .mistral import FastMistralModel
+from .qwen2 import FastQwen2Model
+from .qwen3 import FastQwen3Model
 from .qwen3_moe import FastQwen3MoeModel
-from .granite   import FastGraniteModel
+from .granite import FastGraniteModel
+
 try:
     from .falcon_h1 import FastFalconH1Model
 except:
     # transformers_version < 4.53.0 does not have falcon_h1 so silenty skip it for now
     pass
-from .dpo       import PatchDPOTrainer, PatchKTOTrainer
+from .dpo import PatchDPOTrainer, PatchKTOTrainer
 from ._utils import is_bfloat16_supported, is_vLLM_available, __version__
-from .rl        import PatchFastRL, vLLMSamplingParams
\ No newline at end of file
+from .rl import PatchFastRL, vLLMSamplingParams
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index f65b7e9b2..29c3b2f99 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -18,7 +18,6 @@
     "SUPPORTS_BFLOAT16",
     "is_bfloat16_supported",
     "is_vLLM_available",
-
     "prepare_model_for_kbit_training",
     "xformers",
     "xformers_attention",
@@ -50,19 +49,16 @@
     "patch_layernorm",
     "patch_torch_compile",
     "patch_model_and_tokenizer",
-
     "patch_unsloth_gradient_checkpointing",
     "unpatch_unsloth_gradient_checkpointing",
     "patch_gradient_checkpointing",
     "unpatch_gradient_checkpointing",
-
     "HAS_CUT_CROSS_ENTROPY",
     "EMPTY_LOGITS",
     "fused_linear_cross_entropy",
     "unsloth_fused_ce_loss",
     "patch_unsloth_smart_gradient_checkpointing",
     "unpatch_unsloth_smart_gradient_checkpointing",
-
     "patch_compiled_autograd",
     "process_vision_info",
     "unsloth_compile_transformers",
@@ -79,6 +75,7 @@
 import torch
 from typing import Union, Optional, List, Any, Callable, Tuple
 from platform import system as platform_system
+
 platform_system = platform_system()
 import numpy as np
 import contextlib
@@ -118,12 +115,10 @@
     unsloth_offloaded_gradient_checkpoint,
     patch_unsloth_gradient_checkpointing,
     unpatch_unsloth_gradient_checkpointing,
-
     Unsloth_Gradient_Checkpointer,
     unsloth_gradient_checkpoint,
     patch_gradient_checkpointing,
     unpatch_gradient_checkpointing,
-
     patch_unsloth_smart_gradient_checkpointing,
     unpatch_unsloth_smart_gradient_checkpointing,
 )
@@ -146,51 +141,65 @@
 from unsloth_zoo.temporary_patches import (
     TEMPORARY_PATCHES,
 )
+
 for temporary_patch in TEMPORARY_PATCHES:
     temporary_patch()
 
 # =============================================
 # Disable some warnings which can get annoying
-warnings.filterwarnings(action = "ignore", category = UserWarning,    module = "torch")
-warnings.filterwarnings(action = "ignore", category = FutureWarning,  module = "torch")
-warnings.filterwarnings(action = "ignore", category = UserWarning,    module = "huggingface_hub")
-warnings.filterwarnings(action = "ignore", category = FutureWarning,  module = "huggingface_hub")
-warnings.filterwarnings(action = "ignore", category = UserWarning,    module = "trl")
-warnings.filterwarnings(action = "ignore", category = FutureWarning,  module = "trl")
-warnings.filterwarnings(action = "ignore", category = FutureWarning,  module = "xformers")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
+warnings.filterwarnings(
+    action = "ignore", category = FutureWarning, module = "huggingface_hub"
+)
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
 warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
-warnings.filterwarnings(action = "ignore", category = UserWarning,    module = "transformers")
-warnings.filterwarnings(action = "ignore", category = FutureWarning,  module = "accelerate")
-warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
+warnings.filterwarnings(
+    action = "ignore", category = RuntimeWarning, module = "multiprocessing"
+)
 warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
-warnings.filterwarnings(action = "ignore", category = UserWarning,    module = "triton")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton")
 # Stop "Special tokens have been added in the vocabulary, ..."
 import logging
-logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
+
+logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL + 1)
+
 
 # Ignore logging messages
 class HideLoggingMessage(logging.Filter):
-    __slots__ = "text",
-    def __init__(self, text): self.text = text
-    def filter(self, x): return not (self.text in x.getMessage())
-pass
+    __slots__ = ("text",)
+
+    def __init__(self, text):
+        self.text = text
+
+    def filter(self, x):
+        return not (self.text in x.getMessage())
+
 
 # Stop vLLM messages
-if os.environ.get('UNSLOTH_ENABLE_LOGGING', '0') != '1':
+if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") != "1":
     try:
         from vllm.worker.worker import logger as vllm_worker_logger
+
         vllm_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
         del vllm_worker_logger
     except:
         pass
     try:
         from vllm.v1.worker.gpu_worker import logger as vllm_gpu_worker_logger
+
         vllm_gpu_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
         del vllm_gpu_worker_logger
     except:
         pass
     try:
         from vllm.executor.executor_base import logger as vllm_executor_logger
+
         vllm_executor_logger.addFilter(HideLoggingMessage("to fall asleep"))
         vllm_executor_logger.addFilter(HideLoggingMessage("to wake up"))
         vllm_executor_logger.addFilter(HideLoggingMessage("Executor is not sleeping"))
@@ -198,48 +207,66 @@ def filter(self, x): return not (self.text in x.getMessage())
     except:
         pass
     try:
-        from vllm.core.block.prefix_caching_block import logger as vllm_prefix_caching_logger
+        from vllm.core.block.prefix_caching_block import (
+            logger as vllm_prefix_caching_logger,
+        )
+
         vllm_prefix_caching_logger.addFilter(HideLoggingMessage("reset prefix cache"))
         del vllm_prefix_caching_logger
     except:
         pass
     try:
         from vllm.v1.core.block_pool import logger as vllm_block_pool_logger
+
         vllm_block_pool_logger.addFilter(HideLoggingMessage("reset prefix cache"))
         del vllm_block_pool_logger
     except:
         pass
     try:
         from vllm.lora.models import logger as vllm_lora_model_logger
-        vllm_lora_model_logger.addFilter(HideLoggingMessage("Regarding multimodal models, vLLM currently only supports adding"))
+
+        vllm_lora_model_logger.addFilter(
+            HideLoggingMessage(
+                "Regarding multimodal models, vLLM currently only supports adding"
+            )
+        )
         del vllm_lora_model_logger
     except:
         pass
     try:
-        from vllm.attention.utils.fa_utils import logger as vllm_attention_utils_fa_utils_logger
-        vllm_attention_utils_fa_utils_logger.addFilter(HideLoggingMessage("Cannot use FA version"))
+        from vllm.attention.utils.fa_utils import (
+            logger as vllm_attention_utils_fa_utils_logger,
+        )
+
+        vllm_attention_utils_fa_utils_logger.addFilter(
+            HideLoggingMessage("Cannot use FA version")
+        )
         del vllm_attention_utils_fa_utils_logger
     except:
         pass
-pass
 
 # The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.
 from transformers.training_args import logger as transformers_training_args_logger
+
 transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups"))
 # torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED.
 transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed"))
 # average_tokens_across_devices is set to True but it is invalid when world size is1
-transformers_training_args_logger.addFilter(HideLoggingMessage("average_tokens_across_devices"))
+transformers_training_args_logger.addFilter(
+    HideLoggingMessage("average_tokens_across_devices")
+)
 del transformers_training_args_logger
 
 # No label_names provided for model class
 from transformers.trainer import logger as transformers_trainer_logger
+
 transformers_trainer_logger.addFilter(HideLoggingMessage("No label_names"))
 del transformers_trainer_logger
 
 # Using the default loss: `ForCausalLMLoss`.
 try:
     from transformers.modeling_utils import logger as transformers_modeling_utils_logger
+
     transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss"))
     del transformers_modeling_utils_logger
 except:
@@ -248,15 +275,23 @@ def filter(self, x): return not (self.text in x.getMessage())
 # The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
 try:
     from accelerate.utils.modeling import logger as accelerate_utils_modeling_logger
-    accelerate_utils_modeling_logger.addFilter(HideLoggingMessage("The model weights are not tied"))
+
+    accelerate_utils_modeling_logger.addFilter(
+        HideLoggingMessage("The model weights are not tied")
+    )
     del accelerate_utils_modeling_logger
 except:
     pass
 
 # Setting `pad_token_id` to `eos_token_id`
 try:
-    from transformers.generation.utils import logger as transformers_generation_utils_logger
-    transformers_generation_utils_logger.addFilter(HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`"))
+    from transformers.generation.utils import (
+        logger as transformers_generation_utils_logger,
+    )
+
+    transformers_generation_utils_logger.addFilter(
+        HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`")
+    )
     # "You have set `compile_config`
     transformers_generation_utils_logger.addFilter(HideLoggingMessage("compile_config"))
     del transformers_generation_utils_logger
@@ -265,7 +300,10 @@ def filter(self, x): return not (self.text in x.getMessage())
 
 # The following generation flags are not valid and may be ignored:
 try:
-    from transformers.generation.configuration_utils import logger as configuration_logger
+    from transformers.generation.configuration_utils import (
+        logger as configuration_logger,
+    )
+
     configuration_logger.addFilter(HideLoggingMessage("following generation flags"))
     del configuration_logger
 except:
@@ -274,6 +312,7 @@ def filter(self, x): return not (self.text in x.getMessage())
 # Gemma3 It is strongly recommended to train Gemma3 models with the `eager`
 try:
     from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger
+
     gemma3_logger.addFilter(HideLoggingMessage("strongly recommended"))
     del gemma3_logger
 except:
@@ -282,6 +321,7 @@ def filter(self, x): return not (self.text in x.getMessage())
 # Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed.
 try:
     from huggingface_hub.file_download import logger as hub_logger
+
     hub_logger.addFilter(HideLoggingMessage("hf_xet"))
     del hub_logger
 except:
@@ -290,6 +330,7 @@ def filter(self, x): return not (self.text in x.getMessage())
 # MXFP4 quantization requires triton >= 3.4.0
 try:
     from transformers.quantizers.quantizer_mxfp4 import logger as mxfp4_logger
+
     mxfp4_logger.addFilter(HideLoggingMessage("requires triton"))
     del mxfp4_logger
 except:
@@ -321,6 +362,7 @@ def filter(self, x): return not (self.text in x.getMessage())
 # Using a slow image processor as `use_fast`
 try:
     from transformers.processing_utils import logger as processing_utils_logger
+
     processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
     del processing_utils_logger
 except:
@@ -328,7 +370,10 @@ def filter(self, x): return not (self.text in x.getMessage())
 
 # Using a slow image processor as `use_fast`
 try:
-    from transformers.models.auto.image_processing_auto import logger as processing_utils_logger
+    from transformers.models.auto.image_processing_auto import (
+        logger as processing_utils_logger,
+    )
+
     processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
     del processing_utils_logger
 except:
@@ -337,6 +382,7 @@ def filter(self, x): return not (self.text in x.getMessage())
 # `use_cache=True` is incompatible with gradient checkpointing
 try:
     from transformers.trainer import logger as trainer_logger
+
     trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
     del trainer_logger
 except:
@@ -345,6 +391,7 @@ def filter(self, x): return not (self.text in x.getMessage())
 # `use_cache=True` is incompatible with gradient checkpointing
 try:
     from transformers.utils.generic import logger as trainer_logger
+
     trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
     del trainer_logger
 except:
@@ -353,6 +400,7 @@ def filter(self, x): return not (self.text in x.getMessage())
 # We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')
 try:
     from transformers.modeling_utils import logger as modeling_utils_logger
+
     modeling_utils_logger.addFilter(HideLoggingMessage("anti-pattern"))
     del modeling_utils_logger
 except:
@@ -361,35 +409,43 @@ def filter(self, x): return not (self.text in x.getMessage())
 # Errors out on
 # Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint
 from transformers.modeling_utils import logger as transformers_logger
+
+
 class _RaiseUninitialized(logging.Handler):
     def __init__(self):
         super().__init__()
+
     def emit(self, record):
         record_lower = str(record).lower()
-        if ("some weights of" in record_lower) and \
-            ("score.weight" not in record_lower) and \
-            ("classifier.weight" not in record_lower) and \
-            ("cls.predictions" not in record_lower) and \
-            ("predictions.decoder" not in record_lower) and \
-            (os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1") == "1"):
+        if (
+            ("some weights of" in record_lower)
+            and ("score.weight" not in record_lower)
+            and ("classifier.weight" not in record_lower)
+            and ("cls.predictions" not in record_lower)
+            and ("predictions.decoder" not in record_lower)
+            and (os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1") == "1")
+        ):
             raise Exception(
-                f"Unsloth: Critical error since some weights are not initialized.\n"\
-                f"Please try updating Unsloth, transformers and timm via:\n"\
-                f"`pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm`\n"\
+                f"Unsloth: Critical error since some weights are not initialized.\n"
+                f"Please try updating Unsloth, transformers and timm via:\n"
+                f"`pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm`\n"
                 f"{str(record)}"
             )
-pass
+
+
 class RaiseUninitialized:
     def __init__(self):
         self.error_handler = _RaiseUninitialized()
         transformers_logger.addHandler(self.error_handler)
+
     def remove(self):
         transformers_logger.removeHandler(self.error_handler)
-pass
+
 
 # Patch get_model_param_count to record correct 4bit / 8bit
 from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled
 
+
 def extract_quant_model_param_count(model):
     """
     Calculate quant model param count based on difference in param class. Returns int for param count.
@@ -401,52 +457,66 @@ def extract_quant_model_param_count(model):
         else:
             count += p.numel()
     return count
-pass
+
 
 def get_model_param_count(model, trainable_only = False):
     """
     Calculate model's total param count. If trainable_only is True then count only those requiring grads
     """
     if is_deepspeed_zero3_enabled():
+
         def numel(p):
             return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
     else:
+
         def numel(p):
             return p.numel()
-    s = sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
-    if (not trainable_only) and \
-        hasattr(model, "config") and \
-        hasattr(model.config, "quantization_config"):
+
+    s = sum(
+        numel(p) for p in model.parameters() if not trainable_only or p.requires_grad
+    )
+    if (
+        (not trainable_only)
+        and hasattr(model, "config")
+        and hasattr(model.config, "quantization_config")
+    ):
         approx = extract_quant_model_param_count(model)
         if approx is not None:
             s = approx
     return s
-pass
+
+
 import transformers.trainer_pt_utils
+
 transformers.trainer_pt_utils.get_model_param_count = get_model_param_count
 import transformers.trainer
+
 transformers.trainer.get_model_param_count = get_model_param_count
 # =============================================
 
 # =============================================
 # Edits all Config files to enable RoPE Scaling for all models
 
+
 # Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
 def patch_mistral_nemo_config(config):
     if "head_dim (" not in config:
-        add_head_dim = "If it is not specified, will default to `8`.\n"\
-            "        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
+        add_head_dim = (
+            "If it is not specified, will default to `8`.\n"
+            "        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"
             "            The attention head dimension."
-        config = config.replace("If it is not specified, will default to `8`.", add_head_dim)
+        )
+        config = config.replace(
+            "If it is not specified, will default to `8`.", add_head_dim
+        )
 
         add_head_dim = "num_key_value_heads=8,\n        head_dim=None,"
         config = config.replace("num_key_value_heads=8,", add_head_dim)
 
         add_head_dim = "self.sliding_window = sliding_window\n        self.head_dim = head_dim or hidden_size // num_attention_heads\n"
         config = config.replace("self.sliding_window = sliding_window", add_head_dim)
-    pass
     return config
-pass
+
 
 try:
     # Some Config files use layer_type_validation
@@ -461,12 +531,22 @@ def patch_mistral_nemo_config(config):
 except:
     from transformers import PretrainedConfig
 
-model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite", "qwen3", "qwen3_moe", "falcon_h1"]
+model_architectures = [
+    "llama",
+    "mistral",
+    "gemma",
+    "gemma2",
+    "qwen2",
+    "granite",
+    "qwen3",
+    "qwen3_moe",
+    "falcon_h1",
+]
 
 for model_name in model_architectures:
     config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
     model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
-    config_filename = f"{model_name.title().replace('_','')}Config" # qwen3 arch folder is qwen3_moe but config is Qwen3Config. Need to remove underscore(_) for now
+    config_filename = f"{model_name.title().replace('_', '')}Config"  # qwen3 arch folder is qwen3_moe but config is Qwen3Config. Need to remove underscore(_) for now
     try:
         exec(f"from {config_filepath} import {config_filename}", globals())
     except:
@@ -476,11 +556,12 @@ def patch_mistral_nemo_config(config):
         config = inspect.getsource(eval(config_filename))
     except:
         continue
-    if "rope_scaling" in config: continue
+    if "rope_scaling" in config:
+        continue
     config = re.sub(
         r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
-        r"rope_scaling=None,"\
-        r"\n        **kwargs):\n"\
+        r"rope_scaling=None,"
+        r"\n        **kwargs):\n"
         r"\n        self.rope_scaling = rope_scaling\n",
         config,
     )
@@ -489,12 +570,10 @@ def patch_mistral_nemo_config(config):
     if model_name == "mistral":
         if Version(transformers_version) <= Version("4.42.4"):
             config = patch_mistral_nemo_config(config)
-    pass
 
     exec(config, globals())
     exec(f"import {config_filepath}", globals())
     exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
-pass
 # =============================================
 
 # =============================================
@@ -507,7 +586,6 @@ def patch_mistral_nemo_config(config):
     else:
         torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
         torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
-    pass
 elif DEVICE_TYPE == "xpu":
     if Version(torch_version) < Version("2.6.0"):
         raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
@@ -541,16 +619,18 @@ def patch_mistral_nemo_config(config):
 # =============================================
 # Weird Databricks errors
 from transformers.utils import is_openai_available
+
 if is_openai_available():
     try:
         from openai import OpenAI
     except:
         print("Unsloth: OpenAI failed to import - ignoring for now.")
         import transformers.utils
-        def _is_openai_available(): return False
+
+        def _is_openai_available():
+            return False
+
         transformers.utils.is_openai_available = _is_openai_available
-    pass
-pass
 
 # =============================================
 # Get Flash Attention v2 if Ampere (RTX 30xx, A100)
@@ -581,37 +661,44 @@ def _is_openai_available(): return False
 
                 # Also check for softcapping
                 from flash_attn import __version__ as flash_attn_version
-                HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
+
+                HAS_FLASH_ATTENTION_SOFTCAPPING = Version(
+                    flash_attn_version
+                ) >= Version("2.6.3")
                 if not HAS_FLASH_ATTENTION_SOFTCAPPING:
                     print(
-                        "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
-                        "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
-                        "To update flash-attn, do the below:\n"\
+                        "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
+                        "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
+                        "To update flash-attn, do the below:\n"
                         '\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
                     )
             except:
                 print(
-                    "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
-                    "A possible explanation is you have a new CUDA version which isn't\n"\
-                    "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
-                    "We shall now use Xformers instead, which does not have any performance hits!\n"\
+                    "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"
+                    "A possible explanation is you have a new CUDA version which isn't\n"
+                    "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"
+                    "We shall now use Xformers instead, which does not have any performance hits!\n"
                     "We found this negligible impact by benchmarking on 1x A100."
                 )
 
                 # Stop Flash Attention from importing!
                 import transformers.utils.import_utils
-                transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
+
+                transformers.utils.import_utils.is_flash_attn_2_available = (
+                    lambda *args, **kwargs: False
+                )
                 import transformers.utils
-                transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
+
+                transformers.utils.is_flash_attn_2_available = (
+                    lambda *args, **kwargs: False
+                )
 
                 HAS_FLASH_ATTENTION = False
-            pass
         else:
             HAS_FLASH_ATTENTION = False
     else:
         # Tri Dao's benchmark shows xformers is faster for now.
         HAS_FLASH_ATTENTION = False
-    pass
 elif DEVICE_TYPE == "hip":
     SUPPORTS_BFLOAT16 = True
     if _is_package_available("flash_attn"):
@@ -626,27 +713,34 @@ def _is_openai_available(): return False
 
             # Also check for softcapping
             from flash_attn import __version__ as flash_attn_version
-            HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
+
+            HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version(
+                "2.6.3"
+            )
             if not HAS_FLASH_ATTENTION_SOFTCAPPING:
                 print(
-                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
-                    "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
-                    "To update flash-attn, do the below:\n"\
+                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
+                    "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
+                    "To update flash-attn, do the below:\n"
                     '\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
                 )
         except:
             print(
-                "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
-                "A possible explanation is you have a new CUDA version which isn't\n"\
-                "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
-                "We shall now use Xformers instead, which does not have any performance hits!\n"\
+                "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"
+                "A possible explanation is you have a new CUDA version which isn't\n"
+                "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"
+                "We shall now use Xformers instead, which does not have any performance hits!\n"
                 "We found this negligible impact by benchmarking on 1x A100."
             )
 
             # Stop Flash Attention from importing!
             import transformers.utils.import_utils
-            transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
+
+            transformers.utils.import_utils.is_flash_attn_2_available = (
+                lambda *args, **kwargs: False
+            )
             import transformers.utils
+
             transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
 
             HAS_FLASH_ATTENTION = False
@@ -657,97 +751,103 @@ def _is_openai_available(): return False
 # Get Xformers
 try:
     from xformers import __version__ as xformers_version
+
     # [TODO] Xformers does NOT work on RTX 50x (12), B200 (10), Jetson (11)
     # See https://github.com/facebookresearch/xformers/issues/1329
     # CUDA error (/workspace/xfrm2/third_party/flash-attention/hopper/flash_fwd_launch_template.h:188)
     major_version, minor_version = torch.cuda.get_device_capability()
-    if (
-        (f"{major_version}.{minor_version}" in ("10.0", "11.0", "12.0")) and \
-        (Version(xformers_version) in (Version("0.0.32.post2"),))
+    if (f"{major_version}.{minor_version}" in ("10.0", "11.0", "12.0")) and (
+        Version(xformers_version) in (Version("0.0.32.post2"),)
     ):
         raise NotImplementedError(
-            "Unsloth: Xformers does not work in RTX 50X, Blackwell GPUs as of yet. Please build from source via\n"\
-            "```\n"\
-            "pip install ninja\n"\
-            "pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n"\
+            "Unsloth: Xformers does not work in RTX 50X, Blackwell GPUs as of yet. Please build from source via\n"
+            "```\n"
+            "pip install ninja\n"
+            "pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n"
             "```\n"
         )
     # Temporarily disable 0.0.27 and higher - inference issues
-    if False: #Version(xformers_version) >= Version("0.0.27"):
+    if False:  # Version(xformers_version) >= Version("0.0.27"):
         raise ImportError(
-            "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
-            "then press Disconnect Runtime and then Restart it.\n"\
-            "\n"\
+            "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
+            "then press Disconnect Runtime and then Restart it.\n"
+            "\n"
             "%%capture\n"
             "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
             '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
-            '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
-            '\n'\
-            f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\
+            '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
+            "\n"
+            f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"
             'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"'
         )
-    pass
 
-    if   Version(torch_version) < Version("2.2.0") and Version(xformers_version) >= Version("0.0.24"):
+    if Version(torch_version) < Version("2.2.0") and Version(
+        xformers_version
+    ) >= Version("0.0.24"):
         raise ImportError(
-            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
+            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
             f"Please install xformers < 0.0.24 for torch = {torch_version}."
         )
-    elif Version(torch_version) < Version("2.3.0") and Version(xformers_version) >= Version("0.0.26"):
+    elif Version(torch_version) < Version("2.3.0") and Version(
+        xformers_version
+    ) >= Version("0.0.26"):
         raise ImportError(
-            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
+            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
             f"Please install xformers < 0.0.26 for torch = {torch_version}."
         )
-    elif Version(torch_version) < Version("2.4.0") and Version(xformers_version) > Version("0.0.27"):
+    elif Version(torch_version) < Version("2.4.0") and Version(
+        xformers_version
+    ) > Version("0.0.27"):
         raise ImportError(
-            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
+            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
             f"Please install xformers <= 0.0.27 for torch = {torch_version}."
         )
-    pass
 
     from xformers._cpp_lib import _register_extensions
+
     try:
-        _register_extensions() # Check if C++ modules are loaded correctly
+        _register_extensions()  # Check if C++ modules are loaded correctly
     except Exception as error:
         raise ImportError(
-            "Unsloth: Xformers was not installed correctly.\n"\
-            "Please install xformers separately first.\n"\
-            "Then confirm if it's correctly installed by running:\n"\
+            "Unsloth: Xformers was not installed correctly.\n"
+            "Please install xformers separately first.\n"
+            "Then confirm if it's correctly installed by running:\n"
             "python -m xformers.info\n\n"
             "Longer error message:\n" + str(error)
         )
-    pass
     import xformers.ops.fmha as xformers
+
     xformers_attention = xformers.memory_efficient_attention
 except ModuleNotFoundError:
     xformers = None
     xformers_attention = None
     xformers_version = None
 except Exception as e:
-    print("========\nSwitching to PyTorch attention since your Xformers is broken.\n========\n")
+    print(
+        "========\nSwitching to PyTorch attention since your Xformers is broken.\n========\n"
+    )
     print(str(e))
     xformers = None
     xformers_attention = None
     xformers_version = None
-pass
 
 # Check TRL version
 from trl import __version__ as trl_version
+
 # Unsloth now supports all TRL versions!
-if False:#Version(trl_version) >= Version("0.9.0"):
+if False:  # Version(trl_version) >= Version("0.9.0"):
     raise ImportError(
-        "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
-        "then press Disconnect Runtime and then Restart it.\n"\
-        "\n"\
+        "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
+        "then press Disconnect Runtime and then Restart it.\n"
+        "\n"
         "%%capture\n"
         "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
         '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
-        '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
-        '\n'\
-        f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\
-        'Please downgrade TRL via `pip install --force-reinstall trl'
+        '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
+        "\n"
+        f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"
+        "Please downgrade TRL via `pip install --force-reinstall trl"
     )
-pass
 
 # =============================================
 # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
@@ -773,30 +873,43 @@ def _is_openai_available(): return False
 
 # Transformers 4.46 breaks dynamic caching. This is a hack
 import transformers.generation.configuration_utils
+
 if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"):
-    if type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS) is list:
-        if "dynamic" not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS:
-            transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic")
-    pass
-pass
+    if (
+        type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS)
+        is list
+    ):
+        if (
+            "dynamic"
+            not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS
+        ):
+            transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append(
+                "dynamic"
+            )
 # =============================================
 
 # =============================================
 # Torch compile settings
-UNSLOTH_COMPILE_DEBUG         = os.environ.get("UNSLOTH_COMPILE_DEBUG",         "0") == "1"
-UNSLOTH_COMPILE_MAXIMUM       = os.environ.get("UNSLOTH_COMPILE_MAXIMUM",       "0") == "1"
-UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1"
+UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
+UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1"
+UNSLOTH_COMPILE_IGNORE_ERRORS = (
+    os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1"
+)
 # Just remove max_autotune_gemm warning
 from torch._inductor.runtime.hints import DeviceProperties
 
+
 @functools.lru_cache(None)
 def is_big_gpu(index) -> bool:
-
     if DEVICE_TYPE == "xpu":
-        prop = DeviceProperties.create(torch.device("xpu", index) if type(index) is int else index)
+        prop = DeviceProperties.create(
+            torch.device("xpu", index) if type(index) is int else index
+        )
         min_sms = 16
     else:
-        prop = DeviceProperties.create(torch.device("cuda", index) if type(index) is int else index)
+        prop = DeviceProperties.create(
+            torch.device("cuda", index) if type(index) is int else index
+        )
         min_sms = 80
 
     avail_sms = prop.multi_processor_count
@@ -804,7 +917,9 @@ def is_big_gpu(index) -> bool:
         return False
     return True
 
+
 import torch._inductor.utils
+
 torch._inductor.utils.is_big_gpu = is_big_gpu
 patch_torch_compile(
     debug = UNSLOTH_COMPILE_DEBUG,
@@ -813,77 +928,101 @@ def is_big_gpu(index) -> bool:
 )
 
 torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : True,
-    "shape_padding"     : True,
-    "trace.enabled"     : UNSLOTH_COMPILE_DEBUG,
-    "triton.cudagraphs" : False,
+    "epilogue_fusion": True,
+    "max_autotune": True,
+    "shape_padding": True,
+    "trace.enabled": UNSLOTH_COMPILE_DEBUG,
+    "triton.cudagraphs": False,
 }
 
 import accelerate
+
+
 def torch_compile_kwargs(*args, **kwargs):
     print("Unsloth: Enabled auto compiling")
-    return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options,}
-pass
+    return {
+        "dynamic": True,
+        "fullgraph": False,
+        "options": torch_compile_options,
+    }
+
 
 accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
-accelerate.utils.TorchDynamoPlugin.to_kwargs             = torch_compile_kwargs
-accelerate.accelerator.TorchDynamoPlugin.to_kwargs       = torch_compile_kwargs
+accelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
+accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
 del accelerate
 
+
 def patch_regional_compilation():
     # Regional torch 2.5 Recompilation - weirdly very slow??
-    if torch.nn.ModuleList.__name__ == "UnslothModuleList": return
+    if torch.nn.ModuleList.__name__ == "UnslothModuleList":
+        return
     # Only works for torch 2.5
-    if Version(torch.__version__) < Version("2.5.0"): return
+    if Version(torch.__version__) < Version("2.5.0"):
+        return
 
     old_module_list = torch.nn.ModuleList
     os.environ["UNSLOTH_PATCHED"] = "1"
 
     def UnslothModuleList(*args, **kwargs):
         if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list:
-            args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])]
+            args = [
+                old_module_list(
+                    [
+                        torch.compile(
+                            x,
+                            dynamic = True,
+                            options = torch_compile_options,
+                            fullgraph = False,
+                        )
+                        for x in args[0]
+                    ]
+                )
+            ]
         return old_module_list(*args, **kwargs)
-    pass
+
     UnslothModuleList.__doc__ = old_module_list.__doc__
 
     torch.nn.ModuleList = UnslothModuleList
     return
-pass
+
 
 # =============================================
 
+
 def prepare_model_for_kbit_training(
-    model                      : Any,
-    use_gradient_checkpointing : Optional = True,
-    use_reentrant              : Optional[bool] = True,
+    model: Any,
+    use_gradient_checkpointing: Optional = True,
+    use_reentrant: Optional[bool] = True,
 ) -> Any:
     return prepare_model_for_training(
-        model                      = model,
+        model = model,
         use_gradient_checkpointing = use_gradient_checkpointing,
-        use_reentrant              = use_reentrant,
-        full_finetuning            = False,
-        train_layernorms           = False,
-        train_embedding            = False,
-        train_lm_head              = False,
-        float32_mixed_precision    = True,
+        use_reentrant = use_reentrant,
+        full_finetuning = False,
+        train_layernorms = False,
+        train_embedding = False,
+        train_lm_head = False,
+        float32_mixed_precision = True,
     )
-pass
+
 
 # =============================================
 # Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
 # For mixed precision, we need it to be in float32 not float16.
 from peft import __version__ as peft_version
 from peft.utils.integrations import dequantize_module_weight
+
 if Version(peft_version) < Version("0.12.0"):
     from peft.tuners.lora.layer import LoraLayer
+
     try:
         source = inspect.getsource(LoraLayer.update_layer)
         text = "if weight is not None:\n"
         start = source.find(text) + len(text)
         end = source.find("self.to(weight.device)", start)
         spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
-        source = source.replace(source[start : end], spaces)
+        source = source.replace(source[start:end], spaces)
         spaces = len(re.match(r"[\s]{1,}", source).group(0))
         lines = source.split("\n")
         source = "\n".join(x[spaces:] for x in lines)
@@ -893,40 +1032,46 @@ def prepare_model_for_kbit_training(
 
         # Fix up incorrect downcasting of LoRA weights
         from peft.tuners.lora.layer import LoraLayer
+
         LoraLayer.update_layer = LoraLayer_update_layer
         from peft.tuners.lora import LoraLayer
+
         LoraLayer.update_layer = LoraLayer_update_layer
     except:
         logger.warning_once(
-            "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\
+            "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"
             "Luckily, your training run will still work in the meantime!"
         )
-    pass
-pass
 
 # =============================================
 import importlib
+
 global USE_MODELSCOPE
 USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
 if USE_MODELSCOPE:
     if importlib.util.find_spec("modelscope") is None:
-        raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
-    pass
-pass
+        raise ImportError(
+            f"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`"
+        )
 
 import socket
+
+
 @functools.lru_cache(1)
 def has_internet(host = "8.8.8.8", port = 53, timeout = 3):
-    if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1": return False
+    if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
+        return False
     try:
         socket.setdefaulttimeout(timeout)
         socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect((host, port))
         return True
     except socket.error as ex:
         return False
-pass
+
 
 import psutil
+
+
 def _get_statistics(statistics = None, force_download = True):
     # We log some basic stats about which environment is being used.
     # We simply download a README.md file from HF - all data is made public.
@@ -938,17 +1083,26 @@ def _get_statistics(statistics = None, force_download = True):
     global USE_MODELSCOPE
     USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
 
-    if statistics is not None: pass
-    elif "\nCOLAB_"  in keynames and n_cpus == 1: statistics = "colab"
-    elif "\nCOLAB_"  in keynames: statistics = "colabpro"
-    elif "\nKAGGLE_" in keynames: statistics = "kaggle"
-    elif "\nRUNPOD_" in keynames: statistics = "runpod"
-    elif "\nAWS_"    in keynames: statistics = "aws"
-    elif "\nAZURE_"  in keynames: statistics = "azure"
+    if statistics is not None:
+        pass
+    elif "\nCOLAB_" in keynames and n_cpus == 1:
+        statistics = "colab"
+    elif "\nCOLAB_" in keynames:
+        statistics = "colabpro"
+    elif "\nKAGGLE_" in keynames:
+        statistics = "kaggle"
+    elif "\nRUNPOD_" in keynames:
+        statistics = "runpod"
+    elif "\nAWS_" in keynames:
+        statistics = "aws"
+    elif "\nAZURE_" in keynames:
+        statistics = "azure"
     # elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
-    elif "\nINVOCATION_ID" in keynames: statistics = "lambda"
+    elif "\nINVOCATION_ID" in keynames:
+        statistics = "lambda"
     # else: statistics = "other"
     else:
+
         def try_vllm_check():
             vendor_files = (
                 "/sys/class/dmi/id/product_version",
@@ -958,44 +1112,54 @@ def try_vllm_check():
                 "/sys/class/dmi/id/sys_vendor",
             )
             from pathlib import Path
+
             for vendor_file in vendor_files:
                 path = Path(vendor_file)
                 if path.is_file():
                     file_content = path.read_text().lower()
-                    if   "amazon"                in file_content: return "aws"
-                    elif "microsoft corporation" in file_content: return "azure"
-                    elif "google"                in file_content: return "gcp"
+                    if "amazon" in file_content:
+                        return "aws"
+                    elif "microsoft corporation" in file_content:
+                        return "azure"
+                    elif "google" in file_content:
+                        return "gcp"
             return "other"
-        pass
-        try:    statistics = try_vllm_check()
-        except: statistics = "other"
-    pass
+
+        try:
+            statistics = try_vllm_check()
+        except:
+            statistics = "other"
     if statistics is not None:
         import tempfile
         from huggingface_hub import snapshot_download
         from unsloth_zoo.rl_environments import execute_with_time_limit
+
         if has_internet():
+
             @execute_with_time_limit(120)
             def stats_check():
                 with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:
-                    snapshot_download(f"unslothai/{statistics}", force_download = True, cache_dir = f, local_dir = f)
+                    snapshot_download(
+                        f"unslothai/{statistics}",
+                        force_download = True,
+                        cache_dir = f,
+                        local_dir = f,
+                    )
+
             try:
                 stats_check()
             except TimeoutError:
                 raise TimeoutError(
-                    "Unsloth: HuggingFace seems to be down after trying for 120 seconds :(\n"\
-                    "Check https://status.huggingface.co/ for more details.\n"\
-                    "As a temporary measure, use modelscope with the same model name ie:\n"\
-                    "```\n"\
-                    "pip install modelscope\n"\
-                    "import os; os.environ['UNSLOTH_USE_MODELSCOPE'] = '1'\n"\
-                    "from unsloth import FastLanguageModel\n"\
-                    "model = FastLanguageModel.from_pretrained('unsloth/gpt-oss-20b')\n"\
+                    "Unsloth: HuggingFace seems to be down after trying for 120 seconds :(\n"
+                    "Check https://status.huggingface.co/ for more details.\n"
+                    "As a temporary measure, use modelscope with the same model name ie:\n"
+                    "```\n"
+                    "pip install modelscope\n"
+                    "import os; os.environ['UNSLOTH_USE_MODELSCOPE'] = '1'\n"
+                    "from unsloth import FastLanguageModel\n"
+                    "model = FastLanguageModel.from_pretrained('unsloth/gpt-oss-20b')\n"
                     "```"
                 )
-        pass
-    pass
-pass
 
 
 def get_statistics(local_files_only = False):
@@ -1005,35 +1169,58 @@ def get_statistics(local_files_only = False):
     # This is simply so we can check if some envs are broken or not.
     # You can disable this by setting UNSLOTH_DISABLE_STATISTICS
     import os
-    if "UNSLOTH_DISABLE_STATISTICS" in os.environ: return
-    if local_files_only: return
-    from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
+
+    if "UNSLOTH_DISABLE_STATISTICS" in os.environ:
+        return
+    if local_files_only:
+        return
+    from huggingface_hub.utils import (
+        disable_progress_bars,
+        enable_progress_bars,
+        are_progress_bars_disabled,
+    )
+
     disabled = False
     if not are_progress_bars_disabled():
         disable_progress_bars()
         disabled = True
-    pass
     _get_statistics(None)
     _get_statistics("repeat", force_download = False)
-    total_memory = torch.xpu.get_device_properties(0).total_memory if DEVICE_TYPE == "xpu" else torch.cuda.get_device_properties(0).total_memory
+    total_memory = (
+        torch.xpu.get_device_properties(0).total_memory
+        if DEVICE_TYPE == "xpu"
+        else torch.cuda.get_device_properties(0).total_memory
+    )
     vram = total_memory / 1024 / 1024 / 1024
-    if   vram <= 8 : vram = 8
-    elif vram <= 16: vram = 16
-    elif vram <= 20: vram = 20
-    elif vram <= 24: vram = 24
-    elif vram <= 40: vram = 40
-    elif vram <= 48: vram = 48
-    elif vram <= 80: vram = 80
-    else: vram = 96
+    if vram <= 8:
+        vram = 8
+    elif vram <= 16:
+        vram = 16
+    elif vram <= 20:
+        vram = 20
+    elif vram <= 24:
+        vram = 24
+    elif vram <= 40:
+        vram = 40
+    elif vram <= 48:
+        vram = 48
+    elif vram <= 80:
+        vram = 80
+    else:
+        vram = 96
     _get_statistics(f"vram-{vram}")
     _get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}")
-    if disabled: enable_progress_bars()
-pass
+    if disabled:
+        enable_progress_bars()
 
 
 # =============================================
 # Fixes Bitsandbytes to remove missing warnings
-from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
+from transformers.utils.quantization_config import (
+    BitsAndBytesConfig,
+    QuantizationMethod,
+)
+
 BitsAndBytesConfig__init__ = inspect.getsource(BitsAndBytesConfig.__init__)
 BitsAndBytesConfig__init__ = re.sub(
     r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
@@ -1043,7 +1230,9 @@ def get_statistics(local_files_only = False):
 )
 BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
 length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
-BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
+BitsAndBytesConfig__init__ = "\n".join(
+    x[length_spaces:] for x in BitsAndBytesConfig__init__
+)
 BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
     "__init__",
     "_BitsAndBytesConfig__init__",
@@ -1052,11 +1241,17 @@ def get_statistics(local_files_only = False):
 
 if DEVICE_COUNT == 1:
     from accelerate.utils.dataclasses import DistributedType
-    def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO
+
+    def _prepare_backend(self, *args, **kwargs):
+        return None, DistributedType.NO
+
     import accelerate.state
+
     accelerate.state.PartialState._prepare_backend = _prepare_backend
-    accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO
-pass
+    accelerate.accelerator.Accelerator.distributed_type = (
+        lambda *args, **kwargs: DistributedType.NO
+    )
+
 
 # to move multiple tensors to the same device
 def move_to_device(target_device, *tensors):
@@ -1079,7 +1274,6 @@ def move_to_device(target_device, *tensors):
         pass
     else:
         raise ValueError(f"Invalid target device: {target_device}")
-    pass
     moved_tensors = []
     for tensor in tensors:
         if tensor.device != target_device:
@@ -1088,61 +1282,81 @@ def move_to_device(target_device, *tensors):
             moved_tensors.append(tensor)
     return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0]
 
+
 import transformers.utils.quantization_config
-transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
+
+transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = (
+    _BitsAndBytesConfig__init__
+)
 # =============================================
 
 # Offloading to disk for modules (lm_head, embed_tokens)
 import pickle
 
-def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"):
+
+def offload_to_disk(
+    W, model, name, temporary_location: str = "_unsloth_temporary_saved_buffers"
+):
     file_location = os.path.join(temporary_location, model.config._name_or_path)
     if not os.path.exists(file_location):
         os.makedirs(file_location)
-    pass
 
     filename = os.path.join(file_location, f"{name}.pt")
     W = W.weight if hasattr(W, "weight") else W
-    torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
+    torch.save(
+        W,
+        filename,
+        pickle_module = pickle,
+        pickle_protocol = pickle.HIGHEST_PROTOCOL,
+    )
     # We must use weights_only = False due to pickling
-    offloaded_W = torch.load(filename, map_location = "cpu", mmap = True, weights_only = False)
+    offloaded_W = torch.load(
+        filename, map_location = "cpu", mmap = True, weights_only = False
+    )
     offloaded_W._offloaded_file_location = filename
     return offloaded_W
-pass
 
 
-def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
-    offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location)
+def offload_input_embeddings(
+    model, temporary_location: str = "_unsloth_temporary_saved_buffers"
+):
+    offloaded_W = offload_to_disk(
+        model.get_input_embeddings(), model, "input_embeddings", temporary_location
+    )
     new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
     new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
     model.set_input_embeddings(new_input_embeddings)
     return
-pass
 
 
-def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
-    offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location)
+def offload_output_embeddings(
+    model, temporary_location: str = "_unsloth_temporary_saved_buffers"
+):
+    offloaded_W = offload_to_disk(
+        model.get_output_embeddings(), model, "output_embeddings", temporary_location
+    )
 
     new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
     del new_output_embeddings.weight
     new_output_embeddings.weight = offloaded_W
-    new_output_embeddings.in_features  = offloaded_W.shape[1]
+    new_output_embeddings.in_features = offloaded_W.shape[1]
     new_output_embeddings.out_features = offloaded_W.shape[0]
 
-    new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
+    new_output_embeddings._offloaded_file_location = (
+        offloaded_W._offloaded_file_location
+    )
     model.set_output_embeddings(new_output_embeddings)
     return
-pass
 
 
 # Fixes a weird Torch 2.3 bug which says T4s have bfloat16
 def is_bfloat16_supported():
     return SUPPORTS_BFLOAT16
-pass
+
 
 def is_vLLM_available():
     return _is_package_available("vllm")
-pass
+
 
 # Patches models to add RoPE Scaling
 def patch_linear_scaling(
@@ -1151,17 +1365,18 @@ def patch_linear_scaling(
     scaled_rope_module = None,
     attention_module = None,
 ):
-    assert(rope_module is not None and scaled_rope_module is not None)
-    assert(attention_module is not None)
+    assert rope_module is not None and scaled_rope_module is not None
+    assert attention_module is not None
 
     rope_name = rope_module.__name__
     scaled_rope_name = scaled_rope_module.__name__
     model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
-    exec_code = \
-        f"import torch.nn as nn\n"\
-        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
-        f"from {model_filepath} import logger, "\
+    exec_code = (
+        f"import torch.nn as nn\n"
+        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
+        f"from {model_filepath} import logger, "
         f"{model_name.title()}Attention, {model_name.title()}Config"
+    )
 
     try:
         function = inspect.getsource(attention_module.__init__)
@@ -1199,11 +1414,12 @@ def patch_linear_scaling(
     pass
     """
     fix_rope_function = fix_rope_function.format(
-        rope_function        = rope_module.__name__,
+        rope_function = rope_module.__name__,
         scaled_rope_function = scaled_rope_module.__name__,
     )
     rotary_emb = re.findall(
-        r"self\.rotary\_emb \= .+?\)", function,
+        r"self\.rotary\_emb \= .+?\)",
+        function,
         flags = re.DOTALL | re.MULTILINE,
     )
     if len(rotary_emb) == 0:
@@ -1213,7 +1429,6 @@ def patch_linear_scaling(
     function = function.replace(rotary_emb, fix_rope_function, 1)
     function = exec_code + "\n\n" + function
     return init_name, function
-pass
 
 
 # Patches for Llama-3 LlamaExtendedRotaryEmbedding
@@ -1225,21 +1440,22 @@ def patch_llama_rope_scaling(
     attention_module = None,
     longrope_module = None,
 ):
-    assert(\
-        rope_module is not None and \
-        scaled_rope_module is not None and \
-        extended_rope_module is not None
+    assert (
+        rope_module is not None
+        and scaled_rope_module is not None
+        and extended_rope_module is not None
     )
-    assert(attention_module is not None)
+    assert attention_module is not None
 
     rope_name = rope_module.__name__
     scaled_rope_name = scaled_rope_module.__name__
     model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
-    exec_code = \
-        f"import torch.nn as nn\n"\
-        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
-        f"from {model_filepath} import logger, "\
+    exec_code = (
+        f"import torch.nn as nn\n"
+        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
+        f"from {model_filepath} import logger, "
         f"{model_name.title()}Attention, {model_name.title()}Config"
+    )
 
     try:
         function = inspect.getsource(attention_module.__init__)
@@ -1296,22 +1512,24 @@ def patch_llama_rope_scaling(
     """
 
     fix_rope_function = fix_rope_function.format(
-        rope_function          = rope_module.__name__,
-        scaled_rope_function   = scaled_rope_module.__name__,
+        rope_function = rope_module.__name__,
+        scaled_rope_function = scaled_rope_module.__name__,
         extended_rope_function = extended_rope_module.__name__,
-        longrope_rope_function = \
-            (longrope_module if longrope_module is not None else rope_module).__name__
+        longrope_rope_function = (
+            longrope_module if longrope_module is not None else rope_module
+        ).__name__,
     )
     rotary_emb = re.findall(
-        r"self\.rotary\_emb \= .+?\)", function,
+        r"self\.rotary\_emb \= .+?\)",
+        function,
         flags = re.DOTALL | re.MULTILINE,
     )
-    if len(rotary_emb) == 0: return None, function
+    if len(rotary_emb) == 0:
+        return None, function
     rotary_emb = rotary_emb[0]
     function = function.replace(rotary_emb, fix_rope_function, 1)
     function = exec_code + "\n\n" + function
     return init_name, function
-pass
 
 
 def create_boolean_mask(n = 4096, sliding_window = 2048):
@@ -1319,36 +1537,52 @@ def create_boolean_mask(n = 4096, sliding_window = 2048):
     mask = torch.ones(n, n, dtype = torch.bool)
     if sliding_window == 0:
         return torch.triu(mask, diagonal = 1, out = mask)
-    pass
     torch.triu(mask, diagonal = 0, out = mask)
     torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
     mask = mask.T
     torch.logical_not(mask, out = mask)
     return mask
-pass
 
 
 def test_mask_creation():
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
     for n in range(2, 23):
         for s in range(1, 23):
-            correct_mask = AttentionMaskConverter(
-                is_causal = True,
-                sliding_window = s,
-            ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
-            correct_mask = (correct_mask == correct_mask.min())
+            correct_mask = (
+                AttentionMaskConverter(
+                    is_causal = True,
+                    sliding_window = s,
+                )
+                .to_causal_4d(
+                    1,
+                    n,
+                    n,
+                    dtype = torch.float16,
+                )
+                .squeeze(0)
+                .squeeze(0)
+            )
+            correct_mask = correct_mask == correct_mask.min()
             our_mask = create_boolean_mask(n = n, sliding_window = s)
-            assert(torch.all(correct_mask == our_mask))
-        pass
-        correct_mask = AttentionMaskConverter(
-            is_causal = True,
-            sliding_window = None,
-        ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
-        correct_mask = (correct_mask == correct_mask.min())
+            assert torch.all(correct_mask == our_mask)
+        correct_mask = (
+            AttentionMaskConverter(
+                is_causal = True,
+                sliding_window = None,
+            )
+            .to_causal_4d(
+                1,
+                n,
+                n,
+                dtype = torch.float16,
+            )
+            .squeeze(0)
+            .squeeze(0)
+        )
+        correct_mask = correct_mask == correct_mask.min()
         our_mask = create_boolean_mask(n = n, sliding_window = 0)
-        assert(torch.all(correct_mask == our_mask))
-    pass
-pass
+        assert torch.all(correct_mask == our_mask)
 
 
 def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
@@ -1361,48 +1595,50 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
             kwargs.pop("num_items_in_batch")
         elif "num_items_in_batch" not in inputs:
             inputs["num_items_in_batch"] = num_items_in_batch
-        pass
-    pass
 
     # Get gradient accumulation steps if possible
-    if num_items_in_batch is None and \
-        getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1:
-
+    if (
+        num_items_in_batch is None
+        and getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1
+    ):
         inner_model = model
-        if hasattr(inner_model, "base_model"): inner_model = inner_model. base_model
-        if hasattr(inner_model, "model"): inner_model = inner_model.model
+        if hasattr(inner_model, "base_model"):
+            inner_model = inner_model.base_model
+        if hasattr(inner_model, "model"):
+            inner_model = inner_model.model
         name = inner_model.__class__.__name__
 
         logger.warning_once(
-            f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\
-            "Using gradient accumulation will be very slightly less accurate.\n"\
+            f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
+            "Using gradient accumulation will be very slightly less accurate.\n"
             "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
         )
-    pass
     outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
     return outputs
-pass
 
 
 def patch_gradient_accumulation_fix(Trainer):
     # Fixes gradient accumulation
     # Fixes Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.
     import inspect
-    if hasattr(Trainer, "get_batch_samples"):
-        if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples": return
-        if \
-            not inspect.getsource(Trainer.get_batch_samples).strip()\
-            .endswith("return batch_samples, num_items_in_batch"):
 
-            raise NotImplementedError("Unsloth: Please make a Github issue immediately!!")
+    if hasattr(Trainer, "get_batch_samples"):
+        if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples":
+            return
+        if (
+            not inspect.getsource(Trainer.get_batch_samples)
+            .strip()
+            .endswith("return batch_samples, num_items_in_batch")
+        ):
+            raise NotImplementedError(
+                "Unsloth: Please make a Github issue immediately!!"
+            )
         else:
             if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples":
                 Trainer.get_batch_samples = _unsloth_get_batch_samples
-            pass
 
             # Also fix passing in num_items_in_batch
             if not hasattr(Trainer, "_old_compute_loss"):
-
                 # Fix transformers 4.57.0 causing `Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.`
                 function = inspect.getsource(Trainer.compute_loss)
                 if "loss *=" in function or "loss*=" in function:
@@ -1412,12 +1648,18 @@ def patch_gradient_accumulation_fix(Trainer):
 
                     # Import all variables that need importing
                     import transformers.trainer
+
                     items_in_trainer = dir(transformers.trainer)
                     good_items = []
                     for item in items_in_trainer:
-                        if item in function: good_items.append(item)
-                    pass
-                    exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
+                        if item in function:
+                            good_items.append(item)
+                    exec(
+                        "from transformers.trainer import ("
+                        + ", ".join(x for x in good_items)
+                        + ")",
+                        globals(),
+                    )
 
                     # Replace loss*= with loss = loss *
                     function = re.sub(
@@ -1427,23 +1669,21 @@ def patch_gradient_accumulation_fix(Trainer):
                     )
                     exec(function, globals())
                     Trainer.compute_loss = compute_loss
-                pass
                 Trainer._old_compute_loss = Trainer.compute_loss
                 Trainer.compute_loss = _unsloth_pre_compute_loss
-            pass
-        pass
     else:
         logger.warning_once(
-            "Unsloth: We fixed a gradient accumulation bug, "\
-            "but it seems like you don't have the latest transformers version!\n"\
-            "Please update transformers, TRL and unsloth via:\n"\
-            '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`'
+            "Unsloth: We fixed a gradient accumulation bug, "
+            "but it seems like you don't have the latest transformers version!\n"
+            "Please update transformers, TRL and unsloth via:\n"
+            "`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`"
         )
-    pass
 
     # Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps
-    if Trainer.training_step.__name__ == "_unsloth_training_step": return
-    if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters: return
+    if Trainer.training_step.__name__ == "_unsloth_training_step":
+        return
+    if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters:
+        return
 
     function = inspect.getsource(Trainer.training_step)
     where = function.find("def")
@@ -1452,12 +1692,16 @@ def patch_gradient_accumulation_fix(Trainer):
 
     # Import all variables that need importing
     import transformers.trainer
+
     items_in_trainer = dir(transformers.trainer)
     good_items = []
     for item in items_in_trainer:
-        if item in function: good_items.append(item)
-    pass
-    exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
+        if item in function:
+            good_items.append(item)
+    exec(
+        "from transformers.trainer import (" + ", ".join(x for x in good_items) + ")",
+        globals(),
+    )
 
     # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
     # summed it up and did the division before hand, we have to negate it.
@@ -1477,238 +1721,254 @@ def patch_gradient_accumulation_fix(Trainer):
     # Fix when num_items_in_batch is nothing
     # https://github.com/huggingface/transformers/pull/35207
     function = re.sub(
-        r"else:\n"\
-        r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"\
-        r"(.+?)if num_items_in_batch is None\:\n"\
+        r"else:\n"
+        r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
+        r"(.+?)if num_items_in_batch is None\:\n"
         r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
-
-        "else:\n"\
-        "\2if num_items_in_batch is None:\n"\
-        "\3loss = loss / self.args.gradient_accumulation_steps\n"\
+        "else:\n"
+        "\2if num_items_in_batch is None:\n"
+        "\3loss = loss / self.args.gradient_accumulation_steps\n"
         "\1self.accelerator.backward(loss, **kwargs)",
-
         function,
     )
 
     exec(function, globals())
     Trainer.training_step = _unsloth_training_step
-pass
 
 
 def patch_tokenizer(model, tokenizer):
     model, tokenizer = _patch_tokenizer(model, tokenizer)
     if model is not None:
-        model.config.update({"unsloth_version" : __version__})
+        model.config.update({"unsloth_version": __version__})
     return model, tokenizer
-pass
 
 
 def patch_fast_lora():
     import peft.tuners.lora.bnb
+
     peft.tuners.lora.bnb.Linear4bit.forward = fast_lora_forward
-pass
 
 
 def unsloth_compile_transformers(
     dtype,
     model_name,
     model_types,
-    token                   = None,
-    revision                = None,
-    trust_remote_code       = False,
-    sdpa_dynamic_mask       = True,
-    sdpa_bool_masks         = True,
-    sdpa_gqa_replace        = True,
-    sdpa_dynamic_compile    = True,
-    compile_attention       = True,
-    disable_causal_masks    = True,
-    compile_torch_modules   = True,
-    compile_custom_modules  = True,
-    compile_function_calls  = True,
-    fuse_lm_head            = True,
-    gradient_checkpointing  = True,
-    manual_replacements     = True,
-    fast_lora_forwards      = True,
-    fast_residual_stream    = True,
-    accurate_accumulation   = True,
-    epilogue_fusion         = True,
-    max_autotune            = False,
-    shape_padding           = True,
-    cudagraphs              = False,
-    debug                   = False,
-    fullgraph               = True,
-    import_from_cache       = False,
-    disable                 = False,
-    return_logits           = False,
-    unsloth_force_compile   = False,
+    token = None,
+    revision = None,
+    trust_remote_code = False,
+    sdpa_dynamic_mask = True,
+    sdpa_bool_masks = True,
+    sdpa_gqa_replace = True,
+    sdpa_dynamic_compile = True,
+    compile_attention = True,
+    disable_causal_masks = True,
+    compile_torch_modules = True,
+    compile_custom_modules = True,
+    compile_function_calls = True,
+    fuse_lm_head = True,
+    gradient_checkpointing = True,
+    manual_replacements = True,
+    fast_lora_forwards = True,
+    fast_residual_stream = True,
+    accurate_accumulation = True,
+    epilogue_fusion = True,
+    max_autotune = False,
+    shape_padding = True,
+    cudagraphs = False,
+    debug = False,
+    fullgraph = True,
+    import_from_cache = False,
+    disable = False,
+    return_logits = False,
+    unsloth_force_compile = False,
 ):
     if Version(torch_version) < Version("2.4.0"):
         print(
-            "="*30 + \
-            "Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\n"\
-            f"You have Torch version {torch_version}. Please upgrade your Torch version by visiting https://pytorch.org/\n"\
+            "="
+            * 30
+            + "Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\n"
+            f"You have Torch version {torch_version}. Please upgrade your Torch version by visiting https://pytorch.org/\n"
             "For now your models will not get optimized, but will still work for now!"
         )
         return
-    pass
     if trust_remote_code and unsloth_force_compile == False:
         print(
-            "Unsloth: We can't trace models if `trust_remote_code = True`, "\
+            "Unsloth: We can't trace models if `trust_remote_code = True`, "
             "so turning off some optimizations!"
         )
         return model_types, False
     model_types = list(dict().fromkeys(model_types).keys())
-    if disable: return model_types, False
+    if disable:
+        return model_types, False
 
     supports_sdpa = [True]
     for model_type in model_types:
         _unsloth_compile_transformers(
             model_type,
-            sdpa_dynamic_mask      = sdpa_dynamic_mask,
-            sdpa_bool_masks        = sdpa_bool_masks,
-            sdpa_gqa_replace       = sdpa_gqa_replace,
-            sdpa_dynamic_compile   = sdpa_dynamic_compile,
-            compile_attention      = compile_attention,
-            disable_causal_masks   = disable_causal_masks,
-            compile_torch_modules  = compile_torch_modules,
+            sdpa_dynamic_mask = sdpa_dynamic_mask,
+            sdpa_bool_masks = sdpa_bool_masks,
+            sdpa_gqa_replace = sdpa_gqa_replace,
+            sdpa_dynamic_compile = sdpa_dynamic_compile,
+            compile_attention = compile_attention,
+            disable_causal_masks = disable_causal_masks,
+            compile_torch_modules = compile_torch_modules,
             compile_custom_modules = compile_custom_modules,
             compile_function_calls = compile_function_calls,
-            fuse_lm_head           = fuse_lm_head,
+            fuse_lm_head = fuse_lm_head,
             gradient_checkpointing = gradient_checkpointing,
-            manual_replacements    = manual_replacements,
-            fast_lora_forwards     = fast_lora_forwards,
-            fast_residual_stream   = fast_residual_stream,
-            accurate_accumulation  = accurate_accumulation,
-            epilogue_fusion        = epilogue_fusion,
-            max_autotune           = max_autotune,
-            shape_padding          = shape_padding,
-            cudagraphs             = cudagraphs,
-            debug                  = debug,
-            fullgraph              = fullgraph,
-            import_from_cache      = import_from_cache,
-            disable                = disable,
-            return_logits          = return_logits,
-            supports_sdpa          = supports_sdpa,
+            manual_replacements = manual_replacements,
+            fast_lora_forwards = fast_lora_forwards,
+            fast_residual_stream = fast_residual_stream,
+            accurate_accumulation = accurate_accumulation,
+            epilogue_fusion = epilogue_fusion,
+            max_autotune = max_autotune,
+            shape_padding = shape_padding,
+            cudagraphs = cudagraphs,
+            debug = debug,
+            fullgraph = fullgraph,
+            import_from_cache = import_from_cache,
+            disable = disable,
+            return_logits = return_logits,
+            supports_sdpa = supports_sdpa,
         )
-    pass
     # Redo patches which override compiler
     for temporary_patch in TEMPORARY_PATCHES:
         temporary_patch()
     return model_types, supports_sdpa[0]
-pass
+
 
 # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
 # os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-LOGITS_ERROR_STRING = \
-    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
-    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\
-    "```\nimport os\n"\
-    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
-    "trainer.train()\n```\n"\
+LOGITS_ERROR_STRING = (
+    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "
+    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'
+    "```\nimport os\n"
+    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
+    "trainer.train()\n```\n"
     "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
+)
+
+
+def raise_logits_error(*args, **kwargs):
+    raise NotImplementedError(LOGITS_ERROR_STRING)
+
+
+def return_none(*args, **kwargs):
+    return None
+
 
-def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING)
-def return_none(*args, **kwargs): return None
 class EmptyLogits:
-    def __init__(self): return
-    def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error
+    def __init__(self):
+        return
+
+    def raise_getattr_error(self, attr):
+        return return_none if attr == "to" else raise_logits_error
+
     __getitem__ = raise_logits_error
     __getattr__ = raise_getattr_error
-    def __repr__(self): return LOGITS_ERROR_STRING
-    def __str__ (self): return LOGITS_ERROR_STRING
-pass
+
+    def __repr__(self):
+        return LOGITS_ERROR_STRING
+
+    def __str__(self):
+        return LOGITS_ERROR_STRING
+
+
 EMPTY_LOGITS = EmptyLogits()
 functions = dir(torch.Tensor)
 for j, function in enumerate(functions):
     if function.startswith("__") and function.endswith("__"):
-        exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
-        try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
-        except: continue
-pass
+        exec(
+            f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals()
+        )
+        try:
+            exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
+        except:
+            continue
 
 
 def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, model):
     from peft import LoraConfig
 
-    if loftq_config is None: loftq_config = {}
+    if loftq_config is None:
+        loftq_config = {}
 
     signature = str(inspect.signature(LoraConfig))
-    SUPPORTS_LOFTQ  = "loftq_config" in signature
+    SUPPORTS_LOFTQ = "loftq_config" in signature
 
     if lora_dropout != 0:
         logger.warning_once(
-            f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
+            f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"
             f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
         )
-    pass
 
     if bias != "none":
         logger.warning_once(
-            f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
+            f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"
             f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
         )
-    pass
 
-    if not (type(init_lora_weights) is bool or \
-        init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
+    if not (
+        type(init_lora_weights) is bool
+        or init_lora_weights == "gaussian"
+        or init_lora_weights == "loftq"
+    ):
         raise ValueError(
             'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
         )
-    pass
 
     if init_lora_weights == "loftq":
-
         if not SUPPORTS_LOFTQ:
             import peft
+
             raise RuntimeError(
-                f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
-                "Please install PEFT 0.7.2 or higher.\n"\
+                f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"
+                "Please install PEFT 0.7.2 or higher.\n"
                 "You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
             )
-        pass
 
         if loftq_config == {}:
             from peft import LoftQConfig
+
             logger.warning_once(
-                "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
+                "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
                 "We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
             )
             loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
-        pass
 
         if hasattr(model.config, "quantization_config"):
             raise ValueError(
-                "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
+                "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"
                 "Reload your model without any quantization by setting `load_in_4bit = False`."
             )
-        pass
-    pass
 
     return loftq_config
 
+
 def fast_inference_setup(model_name, model_config):
     fast_inference = True
     if not is_vLLM_available():
-        logger.warning_once("Unsloth: vLLM is not installed! Will use Unsloth inference!")
+        logger.warning_once(
+            "Unsloth: vLLM is not installed! Will use Unsloth inference!"
+        )
         fast_inference = False
-    pass
     from unsloth_zoo.vllm_utils import (
         patch_vllm,
         vllm_dynamic_quant_supported,
     )
+
     patch_vllm()
     if model_name.endswith("unsloth-bnb-4bit"):
         if not vllm_dynamic_quant_supported(model_name, model_config):
             # Instead use -bnb-4bit variant
             logger.warning_once(
-                f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\
+                f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"
                 f"we do not yet support fast inference for {model_name}"
             )
-            model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
-        pass
-    pass
+            model_name = model_name[: -len("unsloth-bnb-4bit")] + "bnb-4bit"
     return fast_inference, model_name
 
+
 def patch_peft_fast_inference(model):
     vllm_engine = getattr(model.model, "vllm_engine", None)
     if vllm_engine is not None:
@@ -1718,40 +1978,50 @@ def patch_peft_fast_inference(model):
 
         # Also saving and loading LoRA
         from unsloth_zoo.vllm_utils import save_lora, load_lora
+
         model.save_lora = functools.partial(save_lora, model)
         model.load_lora = functools.partial(load_lora, model)
-    pass
+
 
 def error_out_no_vllm(*args, **kwargs):
-    raise NotImplementedError("Unsloth: vLLM is not yet supported for fast inference for this model! Please use `.generate` instead")
+    raise NotImplementedError(
+        "Unsloth: vLLM is not yet supported for fast inference for this model! Please use `.generate` instead"
+    )
 
 
 try:
     from torchao.core.config import AOBaseConfig
+
     try:
         from torchao.quantization import Int4WeightOnlyConfig
     except:
         print("Unsloth: TorchAO changed `torchao.quantization.Int4WeightOnlyConfig`")
         Int4WeightOnlyConfig = None
-    pass
 except:
     AOBaseConfig = None
     Int4WeightOnlyConfig = None
-    pass
+
+
 @dataclass
 class TorchAOConfig:
-    qat_scheme : str = "int4"
-    base_config : AOBaseConfig = field(
+    qat_scheme: str = "int4"
+    base_config: AOBaseConfig = field(
         default_factory = lambda: Int4WeightOnlyConfig(group_size = 128)
     )
-    group_size : int = 128
+    group_size: int = 128
     filter_fn: Optional[Callable] = None
+
     def __post_init__(self):
         if self.filter_fn is None:
-            self.filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= self.group_size
-pass
+            self.filter_fn = (
+                lambda m, _: isinstance(m, torch.nn.Linear)
+                and m.in_features >= self.group_size
+            )
+
 
-def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: Union[str, TorchAOConfig]) -> torch.nn.Module:
+def _prepare_model_for_qat(
+    model: torch.nn.Module, qat_scheme: Union[str, TorchAOConfig]
+) -> torch.nn.Module:
     """
     Transform a model for Quantization-Aware Training (QAT) during fine-tuning.
 
@@ -1772,25 +2042,41 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: Union[str, TorchA
         base_config = None
         if qat_scheme == "fp8-int4":
             from torchao.quantization import Float8DynamicActivationInt4WeightConfig
+
             group_size = 128
             base_config = Float8DynamicActivationInt4WeightConfig()
-            filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
+            filter_fn = (
+                lambda m, _: isinstance(m, torch.nn.Linear)
+                and m.in_features >= group_size
+            )
         elif qat_scheme == "fp8-fp8":
             from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
-            base_config = Float8DynamicActivationFloat8WeightConfig(granularity = PerRow())
+
+            base_config = Float8DynamicActivationFloat8WeightConfig(
+                granularity = PerRow()
+            )
         elif qat_scheme == "int8-int4":
             from torchao.quantization import Int8DynamicActivationIntxWeightConfig
+
             group_size = 32
-            base_config = Int8DynamicActivationIntxWeightConfig(weight_dtype = torch.int4, weight_granularity = PerGroup(group_size))
-            filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
+            base_config = Int8DynamicActivationIntxWeightConfig(
+                weight_dtype = torch.int4, weight_granularity = PerGroup(group_size)
+            )
+            filter_fn = (
+                lambda m, _: isinstance(m, torch.nn.Linear)
+                and m.in_features >= group_size
+            )
         elif qat_scheme == "int4":
             from torchao.quantization import Int4WeightOnlyConfig
+
             group_size = 128
             base_config = Int4WeightOnlyConfig(group_size = group_size)
-            filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
+            filter_fn = (
+                lambda m, _: isinstance(m, torch.nn.Linear)
+                and m.in_features >= group_size
+            )
         else:
             raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
-        pass
         # Save TorchAO schemes
         torchao_config = TorchAOConfig(
             qat_scheme = qat_scheme,
@@ -1800,10 +2086,10 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: Union[str, TorchA
         )
     else:
         torchao_config = qat_scheme
-        qat_scheme  = torchao_config.qat_scheme
+        qat_scheme = torchao_config.qat_scheme
         base_config = torchao_config.base_config
-        group_size  = torchao_config.group_size
-        filter_fn   = torchao_config.filter_fn
+        group_size = torchao_config.group_size
+        filter_fn = torchao_config.filter_fn
 
     # Save Torchao metadata everywhere
     inner_model = model
@@ -1814,14 +2100,18 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: Union[str, TorchA
     # Quantize with TorchAO
     quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
     return model
-pass
+
 
 def patch_hf_quantizer():
     # To tell hf trainer that the quantized model is trainable
     def make_trainable(self):
         return True
+
     try:
-        from transformers.quantizers.quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
+        from transformers.quantizers.quantizer_finegrained_fp8 import (
+            FineGrainedFP8HfQuantizer,
+        )
+
         FineGrainedFP8HfQuantizer.is_trainable = property(make_trainable)
         FineGrainedFP8HfQuantizer.is_qat_trainable = property(make_trainable)
     except Exception as e:
@@ -1829,10 +2119,11 @@ def make_trainable(self):
 
     try:
         from transformers.quantizers.quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
+
         FbgemmFp8HfQuantizer.is_trainable = property(make_trainable)
         FbgemmFp8HfQuantizer.is_qat_trainable = property(make_trainable)
     except Exception as e:
         logger.warning(f"Failed to patch FbgemmFp8HfQuantizer. Error {e}")
-pass
+
 
 patch_hf_quantizer()
diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py
index b22e5ca22..b460c6237 100644
--- a/unsloth/models/cohere.py
+++ b/unsloth/models/cohere.py
@@ -16,6 +16,7 @@
 from ._utils import __version__
 from unsloth_zoo.hf_utils import dtype_from_config
 from unsloth_zoo.utils import _get_dtype
+
 try:
     from transformers.models.cohere.modeling_cohere import (
         CohereAttention,
@@ -28,20 +29,20 @@
     )
 except:
     from packaging.version import Version
+
     transformers_version = Version(transformers_version)
     if not transformers_version >= Version("4.42"):
         raise ImportError(
-            f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"\
-            f"The minimum required version is 4.42.3.\n"\
-            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
-            f"to obtain the latest transformers build, then restart this session."\
+            f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"
+            f"The minimum required version is 4.42.3.\n"
+            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
+            f"to obtain the latest transformers build, then restart this session."
         )
-    pass
-pass
 
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
+
 # For Pytorch 2.1.1
 try:
     from transformers.models.cohere.modeling_cohere import (
@@ -49,9 +50,8 @@
         CohereFlashAttention2,
     )
 except:
-    CohereSdpaAttention   = CohereAttention
+    CohereSdpaAttention = CohereAttention
     CohereFlashAttention2 = CohereAttention
-pass
 
 
 def fast_layernorm_inference(self, X, out_weight = None):
@@ -63,24 +63,23 @@ def fast_layernorm_inference(self, X, out_weight = None):
     out_weight[:] = self.weight
     XX *= out_weight
     return XX.to(X.dtype)
-pass
 
 
 # QK norm in Cohere
 def CohereAttention_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    bool = False,
-    use_cache:            bool = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -92,24 +91,22 @@ def CohereAttention_fast_forward(
         del self.attention
         del self.q_norm_out_weight
         del self.k_norm_out_weight
-    pass
 
     bsz, q_len, _ = hidden_states.size()
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
+    head_dim = self.head_dim
+    assert n_kv_heads * n_groups == n_heads
 
     Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
+    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
     K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     if self.use_qk_norm:
         Q = fast_layernorm_compiled(self.q_norm, Q)
         K = fast_layernorm_compiled(self.k_norm, K)
-    pass
 
     kv_seq_len = K.shape[-2]
     if past_key_value is not None:
@@ -121,16 +118,14 @@ def CohereAttention_fast_forward(
     else:
         cos, sin = cos[position_ids], sin[position_ids]
         Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
-    pass
 
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
         V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
     past_key_value = (K, V) if use_cache else None
 
     # Attention module
-    if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
+    if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
         # Xformers memory efficient attention
         # Also has Flash Attention v2 dispatching
         Q = Q.transpose(1, 2)
@@ -139,8 +134,8 @@ def CohereAttention_fast_forward(
 
         # Group query attention
         if n_groups != 1:
-            K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
-            V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+            K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+            V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
             K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
             V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
             if hidden_states.requires_grad:
@@ -148,7 +143,6 @@ def CohereAttention_fast_forward(
                 V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
             else:
                 Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-        pass
         A = xformers_attention(Q, K, V, attn_bias = causal_mask)
         A = A.view(bsz, q_len, n_heads, head_dim)
 
@@ -160,56 +154,66 @@ def CohereAttention_fast_forward(
     else:
         # Grouped query attention
         if n_groups != 1:
-            K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
-            V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+            K = K[:, :, None, :, :].expand(
+                bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
+            )
+            V = V[:, :, None, :, :].expand(
+                bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
+            )
             K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
             V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
-        pass
         # Must be contiguous or else results are False!
         # https://github.com/pytorch/pytorch/issues/112577
         Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
         # Needs (batch_size, n_heads, seq_len, head_dim)
         # is_casual and attention_mask must not be both set!
-        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+        A = scaled_dot_product_attention(
+            Q, K, V, attn_mask = attention_mask, is_causal = False
+        )
         # Go back to (batch_size, seq_len, n_heads, head_dim)
         A = A.transpose(1, 2).contiguous()
-    pass
-    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
+    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
     attn_output = self.apply_o(self, attn_output)
     attn_weights = None
     return attn_output, attn_weights, past_key_value
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
 def CohereDecoderLayer_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    Optional[bool] = False,
-    use_cache:            Optional[bool] = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ):
-    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
-        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
+    if use_cache and hasattr(
+        self, "_flag_for_generation"
+    ):  # past_key_value is not None:
+        out_weight = torch.empty(
+            self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
+        )
 
         # Self Attention
         residual = hidden_states
-        hidden_states = fast_layernorm_inference(self.input_layernorm, hidden_states, out_weight)
+        hidden_states = fast_layernorm_inference(
+            self.input_layernorm, hidden_states, out_weight
+        )
         hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
         )
 
         # Fully Connected
@@ -221,105 +225,132 @@ def CohereDecoderLayer_fast_forward(
         residual = hidden_states
         hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)
         hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
         )
 
         # Fully Connected
         hidden_states_mlp = self.mlp(hidden_states)
         hidden_states = residual + hidden_states_attention + hidden_states_mlp
-    pass
 
     outputs = (hidden_states,)
-    if output_attentions: outputs += (self_attn_weights,)
-    if use_cache: outputs += (present_key_value,)
+    if output_attentions:
+        outputs += (self_attn_weights,)
+    if use_cache:
+        outputs += (present_key_value,)
     return outputs
-pass
 
 
 from math import sqrt as math_sqrt
-KV_CACHE_INCREMENT = 256 # KV Cache update size
+
+KV_CACHE_INCREMENT = 256  # KV Cache update size
 torch_nn_functional_softmax = torch.nn.functional.softmax
 torch_matmul = torch.matmul
 
+
 def CohereAttention_fast_forward_inference(
     self,
-    hidden_states:  torch.Tensor,
+    hidden_states: torch.Tensor,
     past_key_value: Optional[Tuple[torch.Tensor]],
     position_ids,
     do_prefill = False,
     attention_mask = None,
 ):
-
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
     dtype = Xn.dtype
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
+    head_dim = self.head_dim
     # assert(n_kv_heads * n_groups == n_heads)
 
     hidden_size = self.config.hidden_size
-    attention_size = n_heads*head_dim
+    attention_size = n_heads * head_dim
     seq_len = K1.shape[-2]
     kv_seq_len = seq_len + 1
 
     # Prefill phase
     # if not hasattr(self, "paged_attention"):
     if do_prefill:
-        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention = torch.empty(
+            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
+            dtype = dtype,
+            device = "cuda:0",
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
         self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
         self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
-        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
-        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
-        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+        self.temp_QA = torch.empty(
+            (2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0"
+        )
+        self.temp_KV = torch.empty(
+            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = "cuda:0"
+        )
+        self.RH_Q = torch.empty(
+            (bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0"
+        )
 
         # Mistral Nemo 12b has weird dimensions
         if attention_size != hidden_size:
-            self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
+            self.temp_O = torch.empty(
+                (1, bsz, hidden_size), dtype = dtype, device = "cuda:0"
+            )
         else:
-            self.temp_O = self.temp_QA[1][:,:,:hidden_size]
-        pass
+            self.temp_O = self.temp_QA[1][:, :, :hidden_size]
 
-        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
+        self.attention = torch.empty(
+            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len),
+            dtype = dtype,
+            device = "cuda:0",
+        )
         self.scalar = 1.0 / math_sqrt(self.head_dim)
         self.half_head_dim = head_dim // 2
         # Cohere has QK layernorms
         if self.use_qk_norm:
-            self.q_norm_out_weight = torch.empty(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
-            self.k_norm_out_weight = torch.empty(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
+            self.q_norm_out_weight = torch.empty(
+                self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0"
+            )
+            self.k_norm_out_weight = torch.empty(
+                self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0"
+            )
         else:
             self.q_norm_out_weight = None
             self.k_norm_out_weight = None
-        pass
     elif kv_seq_len >= self.paged_attention.shape[0]:
-        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
-    pass
+        self.paged_attention.resize_(
+            (
+                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
+                2,
+                bsz,
+                n_kv_heads,
+                head_dim,
+            )
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
+        self.attention.resize_(
+            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
+        )
 
     Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
     Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
     Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
-    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
+    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
     Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
     Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
     if self.use_qk_norm:
         Q = fast_layernorm_inference(self.q_norm, Q, self.q_norm_out_weight)
         K = fast_layernorm_inference(self.k_norm, K, self.k_norm_out_weight)
-    pass
 
     # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
     # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
@@ -329,16 +360,18 @@ def CohereAttention_fast_forward_inference(
     h = self.half_head_dim
 
     RH_Q = self.RH_Q
-    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
-    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
-    torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
+    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
+    torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
     Qn *= cos
     Qn.addcmul_(RH_Q, sin)
 
-    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
-    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
-    torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+    RH_K = RH_Q[
+        :, :n_kv_heads, :, :
+    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+    RH_K[:, :, :, :h] = Kn[:, :, :, h:]
+    RH_K[:, :, :, h:] = Kn[:, :, :, :h]
+    torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
 
@@ -355,40 +388,46 @@ def CohereAttention_fast_forward_inference(
     if sliding_window is not None and kv_seq_len > sliding_window:
         # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
         slicing_tokens = 1 - sliding_window
-        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
-        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
+        Knn = Kn[:, :, slicing_tokens:, :]  # .contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]  # .contiguous()
     else:
         Knn, Vnn = Kn, Vn
-    pass
 
     # Grouped query attention
     _, _, cached_len, _ = Knn.shape
     if n_groups != 1:
-        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
+        Vnn = Vnn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
         Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
         Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-    pass
     # else:
     #     Knn, Vnn = Knn, Vnn
     # pass
 
     # Attention
     if bsz == 1:
-        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
         # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
-        A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+        A = torch_matmul(
+            Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
+        )
         # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
-        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+        A[:] = torch_nn_functional_softmax(
+            A, dim = -1, dtype = torch.float32
+        )  # .to(A.dtype)
         A = torch_matmul(A, Vnn, out = Qn)
     else:
-        A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
-    pass
+        A = scaled_dot_product_attention(
+            Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
+        )
     A = A.transpose(1, 2)
     A = A.reshape(bsz, 1, attention_size)
     A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
     return A, (Kn, Vn)
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
@@ -400,8 +439,15 @@ def CohereModel_fast_forward_inference(
     position_ids,
     attention_mask = None,
 ):
-    out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT))
-    input_ids = input_ids[:,:self.max_seq_length]
+    out_weights = tuple(
+        torch.empty_like(
+            self.model.layers[0].input_layernorm.weight,
+            dtype = torch.float32,
+            device = torch.device(x),
+        )
+        for x in range(DEVICE_COUNT)
+    )
+    input_ids = input_ids[:, : self.max_seq_length]
     hidden_states = self.model.embed_tokens(input_ids)
     hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
     bsz, q_len, hd = hidden_states.shape
@@ -416,7 +462,6 @@ def CohereModel_fast_forward_inference(
         )
     else:
         attention_mask = None
-    pass
 
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
@@ -425,14 +470,18 @@ def CohereModel_fast_forward_inference(
             device_index, hidden_states, position_ids
         )
         residual = hidden_states
-        hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weights[device_index])
-        hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
-            decoder_layer.self_attn,
-            hidden_states = hidden_states,
-            past_key_value = past_key_values[idx],
-            position_ids = position_ids,
-            attention_mask = attention_mask,
-            do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
+        hidden_states = fast_layernorm_inference(
+            decoder_layer.input_layernorm, hidden_states, out_weights[device_index]
+        )
+        hidden_states_attention, present_key_value = (
+            CohereAttention_fast_forward_inference(
+                decoder_layer.self_attn,
+                hidden_states = hidden_states,
+                past_key_value = past_key_values[idx],
+                position_ids = position_ids,
+                attention_mask = attention_mask,
+                do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
+            )
         )
 
         hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states)
@@ -441,8 +490,9 @@ def CohereModel_fast_forward_inference(
         hidden_states = residual
 
         next_decoder_cache.append(present_key_value)
-    pass
-    hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weights[device_index])
+    hidden_states = fast_layernorm_inference(
+        self.model.norm, hidden_states, out_weights[device_index]
+    )
 
     return BaseModelOutputWithPast(
         last_hidden_state = hidden_states,
@@ -450,34 +500,34 @@ def CohereModel_fast_forward_inference(
         hidden_states = [],
         attentions = [],
     )
-pass
 
 
 class FastCohereModel(FastLlamaModel):
-
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "cohere",
-            rope_module        = LlamaRotaryEmbedding,
+            model_name = "cohere",
+            rope_module = LlamaRotaryEmbedding,
             scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
-            attention_module   = CohereAttention,
+            attention_module = CohereAttention,
         )
         if init_name is not None:
             exec(function, globals())
-            CohereAttention.__init__  = eval(init_name)
-        pass
-        CohereAttention      .forward = CohereAttention_fast_forward
-        CohereSdpaAttention  .forward = CohereAttention_fast_forward
+            CohereAttention.__init__ = eval(init_name)
+        CohereAttention.forward = CohereAttention_fast_forward
+        CohereSdpaAttention.forward = CohereAttention_fast_forward
         CohereFlashAttention2.forward = CohereAttention_fast_forward
-        CohereDecoderLayer   .forward = CohereDecoderLayer_fast_forward
-        CohereModel          .forward = LlamaModel_fast_forward
-        CohereForCausalLM    .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
-        PeftModelForCausalLM .forward = PeftModel_fast_forward
+        CohereDecoderLayer.forward = CohereDecoderLayer_fast_forward
+        CohereModel.forward = LlamaModel_fast_forward
+        CohereForCausalLM.forward = CausalLM_fast_forward(
+            CohereModel_fast_forward_inference
+        )
+        PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(CohereForCausalLM)
 
         import transformers.models.cohere.modeling_cohere
-        transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
+
+        transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = (
+            LlamaRotaryEmbedding
+        )
         return
-    pass
-pass
diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py
index 7fc02fcf7..def047e39 100644
--- a/unsloth/models/dpo.py
+++ b/unsloth/models/dpo.py
@@ -17,6 +17,10 @@
     "PatchKTOTrainer",
 ]
 
-def PatchDPOTrainer(): return
 
-def PatchKTOTrainer(): return
+def PatchDPOTrainer():
+    return
+
+
+def PatchKTOTrainer():
+    return
diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py
index 4449f65d3..70223a3c2 100644
--- a/unsloth/models/falcon_h1.py
+++ b/unsloth/models/falcon_h1.py
@@ -22,6 +22,7 @@
     LlamaLinearScalingRotaryEmbedding,
     _LlamaModel_fast_forward_inference,
 )
+
 try:
     from transformers.models.falcon_h1.modeling_falcon_h1 import (
         FalconH1Attention,
@@ -32,21 +33,24 @@
     )
 except:
     from transformers import __version__ as transformers_version
+
     transformers_version = Version(transformers_version)
-    if not transformers_version >= Version("4.53.0"): #TODO: Update when transformers is updated
+    if not transformers_version >= Version(
+        "4.53.0"
+    ):  # TODO: Update when transformers is updated
         raise ImportError(
-            f"Unsloth: Your transformers version of {transformers_version} does not support FalconH1.\n"\
-            f"The minimum required version is 4.53.0.\n"\
-            f'Try `pip install --upgrade "transformers>=4.53.0"`\n'\
-            f"to obtain the latest transformers build, then restart this session."\
+            f"Unsloth: Your transformers version of {transformers_version} does not support FalconH1.\n"
+            f"The minimum required version is 4.53.0.\n"
+            f'Try `pip install --upgrade "transformers>=4.53.0"`\n'
+            f"to obtain the latest transformers build, then restart this session."
         )
-    pass
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.utils import (
     is_torchdynamo_compiling,
 )
+
 # For Pytorch 2.1.1
 try:
     from transformers.models.falcon_h1.modeling_falcon_h1 import (
@@ -56,24 +60,25 @@
     # if we are on a old version of transformers technically it should fail in the try except above
     # but if somehow we make it here, we need to raise an error since FalconH1Attention is not available
     # or renamed
-    raise ImportError("Unsloth: Could not import FalconH1Attention from transformers.models.falcon_h1.modeling_falcon_h1.")
-pass
+    raise ImportError(
+        "Unsloth: Could not import FalconH1Attention from transformers.models.falcon_h1.modeling_falcon_h1."
+    )
 
 
 def FalconH1Attention_fast_forward(
     self,
-    hidden_states:       torch.Tensor,
-    causal_mask:         Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:      Optional[torch.Tensor] = None,
-    position_ids:        Optional[torch.LongTensor] = None,
-    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:   bool = False,
-    use_cache:           bool = False,
-    padding_mask:        Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -83,19 +88,22 @@ def FalconH1Attention_fast_forward(
         del self.temp_KV
         del self.RH_Q
         del self.attention
-    pass
 
     bsz, q_len, _ = hidden_states.size()
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
+    head_dim = self.head_dim
+    assert n_kv_heads * n_groups == n_heads
 
     Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads,    head_dim)#.transpose(1, 2) # we will transpose after normalisation
-    K = K.view(bsz, q_len, n_kv_heads, head_dim)#.transpose(1, 2) # we will transpose after normalisation
+    Q = Q.view(
+        bsz, q_len, n_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
+    K = K.view(
+        bsz, q_len, n_kv_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
     V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
     # Falcon H1 multiplies key states by a multiplier
@@ -126,11 +134,10 @@ def FalconH1Attention_fast_forward(
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
         V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
     past_key_value = (K, V) if use_cache else None
 
     # Attention module
-    if (not HAS_FLASH_ATTENTION and attention_mask is None):
+    if not HAS_FLASH_ATTENTION and attention_mask is None:
         # Xformers memory efficient attention
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
@@ -139,8 +146,8 @@ def FalconH1Attention_fast_forward(
         Q_M = bsz * q_len
 
         # Group query attention
-        K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
-        V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+        K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+        V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
         K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         if hidden_states.requires_grad:
@@ -149,7 +156,6 @@ def FalconH1Attention_fast_forward(
         else:
             # Xformers does support the forward pass though
             Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-        pass
 
         A = xformers_attention(Q, K, V, attn_bias = causal_mask)
         A = A.view(bsz, q_len, n_heads, head_dim)
@@ -174,67 +180,70 @@ def FalconH1Attention_fast_forward(
         Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
         # Needs (batch_size, n_heads, seq_len, head_dim)
         # is_casual and attention_mask must not be both set!
-        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+        A = scaled_dot_product_attention(
+            Q, K, V, attn_mask = attention_mask, is_causal = False
+        )
         # Go back to (batch_size, seq_len, n_heads, head_dim)
         A = A.transpose(1, 2).contiguous()
-    pass
 
-    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
+    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
     attn_output = self.apply_o(self, attn_output)
     attn_weights = None
     return attn_output, attn_weights, past_key_value
-pass
+
 
 torch_matmul = torch.matmul
+
+
 def FalconH1Attention_fast_forward_inference(
     self,
-    hidden_states:  torch.Tensor,
+    hidden_states: torch.Tensor,
     past_key_value: Optional[Tuple[torch.Tensor]],
     position_ids,
     do_prefill = False,
     attention_mask = None,
 ):
     """
-        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
-        Fast inference using KV cache.
-        QK^T can be computed in 4 chunks
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
+    Fast inference using KV cache.
+    QK^T can be computed in 4 chunks
 
-        [Q, q] @ [K, k].T where q, k are the new tokens.
-        [QK^T, Qk^T]
-        [qK^T, qk^T]
+    [Q, q] @ [K, k].T where q, k are the new tokens.
+    [QK^T, Qk^T]
+    [qK^T, qk^T]
 
-        Since the attention mask wipes Qk^T, we just get
-        [QK^T,    0]
-        [qK^T, qk^T]
+    Since the attention mask wipes Qk^T, we just get
+    [QK^T,    0]
+    [qK^T, qk^T]
 
-        Since softmax is row-wise, we get
-        softmax([QK^T,    0])
-        softmax([qK^T, qk^T])
+    Since softmax is row-wise, we get
+    softmax([QK^T,    0])
+    softmax([qK^T, qk^T])
 
-        We then multiply by   [V]
-                              [v]
-        softmax([QK^T,    0]) [softmax(QK^T)V] *
-        softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
+    We then multiply by   [V]
+                          [v]
+    softmax([QK^T,    0]) [softmax(QK^T)V] *
+    softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
 
-        But notice * [softmax(QK^T)V] is just the last attention.
-        We just need to compute the last final row.
+    But notice * [softmax(QK^T)V] is just the last attention.
+    We just need to compute the last final row.
 
-        This means we can pass in a row of Q, but we need to
-        remember K and V, which are called the KV cache.
+    This means we can pass in a row of Q, but we need to
+    remember K and V, which are called the KV cache.
     """
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
     dtype = Xn.dtype
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
+    head_dim = self.head_dim
     # assert(n_kv_heads * n_groups == n_heads)
 
     hidden_size = self.config.hidden_size
-    attention_size = n_heads*head_dim
+    attention_size = n_heads * head_dim
     seq_len = K1.shape[-2]
     kv_seq_len = seq_len + 1
 
@@ -242,38 +251,60 @@ def FalconH1Attention_fast_forward_inference(
     # if not hasattr(self, "paged_attention"):
     device = hidden_states.device
     if do_prefill:
-        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention = torch.empty(
+            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
+            dtype = dtype,
+            device = device,
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
         self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
         self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
-        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
-        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
+        self.temp_QA = torch.empty(
+            (2, bsz, 1, attention_size), dtype = dtype, device = device
+        )
+        self.temp_KV = torch.empty(
+            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
+        )
         self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
 
         # Mistral Nemo 12b has weird dimensions
         if attention_size != hidden_size:
             self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
         else:
-            self.temp_O = self.temp_QA[1][:,:,:hidden_size]
-        pass
+            self.temp_O = self.temp_QA[1][:, :, :hidden_size]
 
-        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
+        self.attention = torch.empty(
+            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
+        )
         self.scalar = 1.0 / math_sqrt(self.head_dim)
         self.half_head_dim = head_dim // 2
     elif kv_seq_len >= self.paged_attention.shape[0]:
-        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
-    pass
+        self.paged_attention.resize_(
+            (
+                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
+                2,
+                bsz,
+                n_kv_heads,
+                head_dim,
+            )
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
+        self.attention.resize_(
+            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
+        )
 
     Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
     Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
     Kn = Kn * self.config.key_multiplier
     Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
-    Qn = Qn.view(bsz, 1, n_heads,    head_dim)#.transpose(1, 2) # we will transpose after normalisation
-    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim)#.transpose(1, 2) # we will transpose after normalisation
+    Qn = Qn.view(
+        bsz, 1, n_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
+    Kn = Kn.view(
+        bsz, 1, n_kv_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
     Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
 
     Qn = Qn.transpose(1, 2)
@@ -291,16 +322,18 @@ def FalconH1Attention_fast_forward_inference(
     h = self.half_head_dim
 
     RH_Q = self.RH_Q
-    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
-    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
-    RH_Q[:,:,:,:h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
+    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
+    RH_Q[:, :, :, :h].neg_()  # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
     Qn *= cos
     Qn.addcmul_(RH_Q, sin)
 
-    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
-    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
-    RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+    RH_K = RH_Q[
+        :, :n_kv_heads, :, :
+    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+    RH_K[:, :, :, :h] = Kn[:, :, :, h:]
+    RH_K[:, :, :, h:] = Kn[:, :, :, :h]
+    RH_K[:, :, :, :h].neg_()  # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
 
@@ -317,59 +350,69 @@ def FalconH1Attention_fast_forward_inference(
     if sliding_window is not None and kv_seq_len > sliding_window:
         # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
         slicing_tokens = 1 - sliding_window
-        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
-        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
+        Knn = Kn[:, :, slicing_tokens:, :]  # .contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]  # .contiguous()
     else:
         Knn, Vnn = Kn, Vn
-    pass
 
     # Grouped query attention
     _, _, cached_len, _ = Knn.shape
     if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1:
-        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
+        Vnn = Vnn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
         Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
         Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-    pass
     # else:
     #     Knn, Vnn = Knn, Vnn
     # pass
 
     # Attention
     if bsz == 1:
-        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
         # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
-        A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+        A = torch_matmul(
+            Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
+        )
         # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
-        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+        A[:] = torch_nn_functional_softmax(
+            A, dim = -1, dtype = torch.float32
+        )  # .to(A.dtype)
         A = torch_matmul(A, Vnn, out = Qn)
     else:
         if SDPA_HAS_GQA:
-            A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True)
+            A = scaled_dot_product_attention(
+                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True
+            )
         else:
-            A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
-    pass
+            A = scaled_dot_product_attention(
+                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
+            )
     A = A.transpose(1, 2)
     A = A.reshape(bsz, 1, attention_size)
     A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
     return A, (Kn, Vn)
-pass
+
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon_h1/modeling_falcon_h1.py
 def FalconH1DecoderLayer_fast_forward(
     self,
-    hidden_states:       torch.Tensor,
-    causal_mask          = None,
-    attention_mask:      Optional[torch.Tensor] = None,
-    mamba_attention_mask:      Optional[torch.Tensor] = None,
-    position_ids:        Optional[torch.LongTensor] = None,
-    cache_position:        Optional[torch.LongTensor] = None,
-    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:   Optional[bool] = False,
-    use_cache:           Optional[bool] = False,
-    padding_mask:        Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    mamba_attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """
     Args:
@@ -386,25 +429,27 @@ def FalconH1DecoderLayer_fast_forward(
     """
     if use_cache and hasattr(self, "_flag_for_generation"):
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.input_layernorm, hidden_states
+        )
         attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states       = hidden_states,
-            causal_mask         = causal_mask,
-            attention_mask      = attention_mask,
-            position_ids        = position_ids,
-            past_key_value      = past_key_value,
-            output_attentions   = output_attentions,
-            use_cache           = use_cache,
-            padding_mask        = padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
         )
         attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
 
         mamba_hidden_states = self.mamba(
-            hidden_states=hidden_states,
-            cache_params=past_key_value,
-            cache_position=cache_position,
-            attention_mask=mamba_attention_mask,
+            hidden_states = hidden_states,
+            cache_params = past_key_value,
+            cache_position = cache_position,
+            attention_mask = mamba_attention_mask,
         )
         mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
 
@@ -414,7 +459,9 @@ def FalconH1DecoderLayer_fast_forward(
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.post_attention_layernorm, hidden_states
+        )
         hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
         hidden_states += residual
     else:
@@ -422,22 +469,22 @@ def FalconH1DecoderLayer_fast_forward(
         hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
 
         mamba_hidden_states = self.mamba(
-            hidden_states=hidden_states,
-            cache_params=past_key_value,
-            cache_position=cache_position,
-            attention_mask=mamba_attention_mask,
+            hidden_states = hidden_states,
+            cache_params = past_key_value,
+            cache_position = cache_position,
+            attention_mask = mamba_attention_mask,
         )
         mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
 
         attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states       = hidden_states,
-            causal_mask         = causal_mask,
-            attention_mask      = attention_mask,
-            position_ids        = position_ids,
-            past_key_value      = past_key_value,
-            output_attentions   = output_attentions,
-            use_cache           = use_cache,
-            padding_mask        = padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
         )
         attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
@@ -452,15 +499,19 @@ def FalconH1DecoderLayer_fast_forward(
         hidden_states = fast_rms_layernorm(self.pre_ff_layernorm, hidden_states)
         hidden_states = self.feed_forward(hidden_states)
         hidden_states = residual + hidden_states
-    pass
 
     outputs = (hidden_states,)
-    if output_attentions: outputs += (self_attn_weights,)
-    if use_cache: outputs += (present_key_value,)
+    if output_attentions:
+        outputs += (self_attn_weights,)
+    if use_cache:
+        outputs += (present_key_value,)
     return outputs
-pass
 
-def _FalconH1_fast_forward_inference(attention_fast_forward_inference=FalconH1Attention_fast_forward_inference, mlp_fast_forward_inference=fast_swiglu_inference):
+
+def _FalconH1_fast_forward_inference(
+    attention_fast_forward_inference = FalconH1Attention_fast_forward_inference,
+    mlp_fast_forward_inference = fast_swiglu_inference,
+):
     # This makes the attention and MLP customisable.
     # Now for models like qwen3 or cohere which use custom attention operations, we can use this function
     def FalconH1Model_fast_forward_inference_custom(
@@ -472,7 +523,7 @@ def FalconH1Model_fast_forward_inference_custom(
         attention_mask = None,
         mamba_attention_mask = None,
     ):
-        input_ids = input_ids[:,:self.max_seq_length]
+        input_ids = input_ids[:, : self.max_seq_length]
         bsz, q_len = input_ids.shape
         hd = self.config.hidden_size
         mlp_size = self.config.intermediate_size
@@ -483,7 +534,7 @@ def FalconH1Model_fast_forward_inference_custom(
 
         X = X.to(_get_dtype(dtype_from_config(self.config)))
         bsz, q_len, hd = X.shape
-        assert(q_len == 1)
+        assert q_len == 1
         # Get saved buffers to reduce memory movement
         residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
         _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
@@ -502,12 +553,11 @@ def FalconH1Model_fast_forward_inference_custom(
             )
         else:
             attention_mask = None
-        pass
 
         next_decoder_cache = []
 
         for idx, decoder_layer in enumerate(self.model.layers):
-            residual.copy_(X) # residual = X
+            residual.copy_(X)  # residual = X
             X = fast_rms_layernorm_inference(
                 decoder_layer.input_layernorm,
                 X,
@@ -515,27 +565,31 @@ def FalconH1Model_fast_forward_inference_custom(
                 XX2 = XX2,
                 variance = variance,
             )
-            attention_hidden_states, present_key_value = attention_fast_forward_inference(
-                decoder_layer.self_attn,
-                hidden_states = X * decoder_layer.attention_in_multiplier,
-                past_key_value = past_key_values[idx],
-                position_ids = position_ids,
-                attention_mask = attention_mask,
-                do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
+            attention_hidden_states, present_key_value = (
+                attention_fast_forward_inference(
+                    decoder_layer.self_attn,
+                    hidden_states = X * decoder_layer.attention_in_multiplier,
+                    past_key_value = past_key_values[idx],
+                    position_ids = position_ids,
+                    attention_mask = attention_mask,
+                    do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
+                )
+            )
+            attention_hidden_states = (
+                attention_hidden_states * decoder_layer.attn_out_multiplier
             )
-            attention_hidden_states = attention_hidden_states * decoder_layer.attn_out_multiplier
             mamba_hidden_states = decoder_layer.mamba(
-                hidden_states=X,
-                cache_params=present_key_value,
-                cache_position=cache_position,
-                attention_mask=mamba_attention_mask,
+                hidden_states = X,
+                cache_params = present_key_value,
+                cache_position = cache_position,
+                attention_mask = mamba_attention_mask,
             )
             mamba_hidden_states = mamba_hidden_states * decoder_layer.ssm_out_multiplier
             X = mamba_hidden_states + attention_hidden_states
 
             X += residual
 
-            residual.copy_(X) # residual = X
+            residual.copy_(X)  # residual = X
             X = fast_rms_layernorm_inference(
                 decoder_layer.pre_ff_layernorm,
                 X,
@@ -549,12 +603,11 @@ def FalconH1Model_fast_forward_inference_custom(
                 temp_gate = temp_gate,
                 temp_up = temp_up,
                 gate_multiplier = gate_multiplier,
-                down_multiplier = down_multiplier
+                down_multiplier = down_multiplier,
             )
             X += residual
 
             next_decoder_cache.append(present_key_value)
-        pass
         X = fast_rms_layernorm_inference(
             self.model.final_layernorm,
             X,
@@ -569,20 +622,22 @@ def FalconH1Model_fast_forward_inference_custom(
             hidden_states = [],
             attentions = [],
         )
-    pass
+
     return FalconH1Model_fast_forward_inference_custom
 
-#Separate prepare_inputs_for_generation for Hybrid FalconH1
+
+# Separate prepare_inputs_for_generation for Hybrid FalconH1
 def _fast_prepare_inputs_for_generation(
     self,
     input_ids,
-    past_key_values=None,
-    attention_mask=None,
-    inputs_embeds=None,
-    cache_position=None,
-    position_ids=None,
-    use_cache=True,
-    **kwargs,):
+    past_key_values = None,
+    attention_mask = None,
+    inputs_embeds = None,
+    cache_position = None,
+    position_ids = None,
+    use_cache = True,
+    **kwargs,
+):
     # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
     empty_past_kv = past_key_values is None
 
@@ -594,12 +649,15 @@ def _fast_prepare_inputs_for_generation(
     if not empty_past_kv:
         if (
             inputs_embeds is not None  # Exception 1
-            or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            or (
+                is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]
+            )  # Exception 3
         ):
             input_ids = input_ids[:, -cache_position.shape[0] :]
-        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+        elif (
+            input_ids.shape[1] != cache_position.shape[0]
+        ):  # Default case (the "else", a no op, is Exception 2)
             input_ids = input_ids[:, cache_position]
-    pass
     # TODO: Wire up Cache to work for inference.
     # else:
     #     past_key_values = FalconHybridMambaAttentionDynamicCache(
@@ -622,7 +680,9 @@ def _fast_prepare_inputs_for_generation(
     if inputs_embeds is not None and empty_past_kv:
         model_inputs = {"inputs_embeds": inputs_embeds}
     else:
-        model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+        model_inputs = {
+            "input_ids": input_ids.contiguous()
+        }  # `contiguous()` needed for compilation use cases
 
     model_inputs.update(
         {
@@ -635,34 +695,32 @@ def _fast_prepare_inputs_for_generation(
         }
     )
     return model_inputs
-pass
 
 
 def fix_prepare_inputs_for_generation(module):
     # Fix prepare_inputs_for_generation
     if hasattr(module, "prepare_inputs_for_generation"):
-            module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
-    pass
-pass
+        module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
 
-class FastFalconH1Model(FastLlamaModel):
 
+class FastFalconH1Model(FastLlamaModel):
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "FalconH1",
-            rope_module        = LlamaRotaryEmbedding,
+            model_name = "FalconH1",
+            rope_module = LlamaRotaryEmbedding,
             scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
-            attention_module   = FalconH1Attention,
+            attention_module = FalconH1Attention,
         )
         if init_name is not None:
             exec(function, globals())
-            FalconH1Attention.__init__  = eval(init_name)
-        pass
-        FalconH1Attention      .forward = FalconH1Attention_fast_forward
-        FalconH1DecoderLayer   .forward = FalconH1DecoderLayer_fast_forward
-        FalconH1Model          .forward = LlamaModel_fast_forward
-        FalconH1ForCausalLM    .forward = CausalLM_fast_forward(_FalconH1_fast_forward_inference(FalconH1Attention_fast_forward_inference))
+            FalconH1Attention.__init__ = eval(init_name)
+        FalconH1Attention.forward = FalconH1Attention_fast_forward
+        FalconH1DecoderLayer.forward = FalconH1DecoderLayer_fast_forward
+        FalconH1Model.forward = LlamaModel_fast_forward
+        FalconH1ForCausalLM.forward = CausalLM_fast_forward(
+            _FalconH1_fast_forward_inference(FalconH1Attention_fast_forward_inference)
+        )
         PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(FalconH1ForCausalLM)
 
@@ -672,39 +730,38 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.falcon_h1.modeling_falcon_h1
-        transformers.models.falcon_h1.modeling_falcon_h1.FalconH1RotaryEmbedding = LlamaRotaryEmbedding
-        return
-    pass
 
+        transformers.models.falcon_h1.modeling_falcon_h1.FalconH1RotaryEmbedding = (
+            LlamaRotaryEmbedding
+        )
+        return
 
     @staticmethod
-    def from_pretrained(  #TODO: Change after release
-        model_name        = "Qwen/FalconH1-7B",
-        max_seq_length    = 4096,
-        dtype             = None,
-        load_in_4bit      = True,
-        token             = None,
-        device_map        = "sequential",
-        rope_scaling      = None,
-        fix_tokenizer     = True,
-        model_patcher     = None,
-        tokenizer_name    = None,
+    def from_pretrained(  # TODO: Change after release
+        model_name = "Qwen/FalconH1-7B",
+        max_seq_length = 4096,
+        dtype = None,
+        load_in_4bit = True,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,
+        fix_tokenizer = True,
+        model_patcher = None,
+        tokenizer_name = None,
         trust_remote_code = False,
         **kwargs,
     ):
         return FastLlamaModel.from_pretrained(
-            model_name        = model_name,
-            max_seq_length    = max_seq_length,
-            dtype             = dtype,
-            load_in_4bit      = load_in_4bit,
-            token             = token,
-            device_map        = device_map,
-            rope_scaling      = rope_scaling,
-            fix_tokenizer     = fix_tokenizer,
-            model_patcher     = FastFalconH1Model,
-            tokenizer_name    = tokenizer_name,
+            model_name = model_name,
+            max_seq_length = max_seq_length,
+            dtype = dtype,
+            load_in_4bit = load_in_4bit,
+            token = token,
+            device_map = device_map,
+            rope_scaling = rope_scaling,
+            fix_tokenizer = fix_tokenizer,
+            model_patcher = FastFalconH1Model,
+            tokenizer_name = tokenizer_name,
             trust_remote_code = trust_remote_code,
             **kwargs,
         )
-    pass
-pass
diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index db869b63f..cf5bdd8eb 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -30,20 +30,20 @@
     )
 except:
     from packaging.version import Version
+
     transformers_version = Version(transformers_version)
     if not transformers_version >= Version("4.38"):
         raise ImportError(
-            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
-            f"The minimum required version is 4.38.\n"\
-            f'Try `pip install --upgrade "transformers>=4.38"`\n'\
-            f"to obtain the latest transformers build, then restart this session."\
+            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"
+            f"The minimum required version is 4.38.\n"
+            f'Try `pip install --upgrade "transformers>=4.38"`\n'
+            f"to obtain the latest transformers build, then restart this session."
         )
-    pass
-pass
 
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
+
 # For Pytorch 2.1.1
 try:
     from transformers.models.gemma.modeling_gemma import (
@@ -51,12 +51,13 @@
         GemmaFlashAttention2,
     )
 except:
-    GemmaSdpaAttention   = GemmaAttention
+    GemmaSdpaAttention = GemmaAttention
     GemmaFlashAttention2 = GemmaAttention
-pass
 
 
 torch_nn_functional_gelu = torch.nn.functional.gelu
+
+
 def fast_geglu_inference(self, X):
     # gate = self.gate_proj(X)
     # up   = self.up_proj(X)
@@ -64,84 +65,97 @@ def fast_geglu_inference(self, X):
     # mlp_size = self.config.intermediate_size
     # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
 
-    gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
-    up   = fast_linear_forward(self.  up_proj, X)#, out = temp[1])
+    gate = fast_linear_forward(self.gate_proj, X)  # , out = temp[0])
+    up = fast_linear_forward(self.up_proj, X)  # , out = temp[1])
     gate = torch_nn_functional_gelu(gate, approximate = "tanh")
     gate *= up
 
     # X = self.down_proj(gate)
-    down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
+    down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])
     return down
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
 def GemmaDecoderLayer_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    Optional[bool] = False,
-    use_cache:            Optional[bool] = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
-    *args, **kwargs,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    *args,
+    **kwargs,
 ):
-    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
-        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
+    if use_cache and hasattr(
+        self, "_flag_for_generation"
+    ):  # past_key_value is not None:
+        out_weight = torch.empty(
+            self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
+        )
 
         # Self Attention
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            self.input_layernorm, hidden_states, out_weight
+        )
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
         )
         hidden_states += residual
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            self.post_attention_layernorm, hidden_states, out_weight
+        )
         hidden_states = fast_geglu_inference(self.mlp, hidden_states)
         hidden_states += residual
     else:
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm(
+            self.input_layernorm, hidden_states, gemma = True
+        )
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
         )
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm(
+            self.post_attention_layernorm, hidden_states, gemma = True
+        )
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-    pass
 
     outputs = (hidden_states,)
-    if output_attentions: outputs += (self_attn_weights,)
-    if use_cache: outputs += (present_key_value,)
+    if output_attentions:
+        outputs += (self_attn_weights,)
+    if use_cache:
+        outputs += (present_key_value,)
     return outputs
-pass
 
 
 from math import sqrt as math_sqrt
 
+
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
 # @torch.inference_mode
 def GemmaModel_fast_forward_inference(
@@ -151,13 +165,22 @@ def GemmaModel_fast_forward_inference(
     position_ids,
     attention_mask = None,
 ):
-    out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT))
-    input_ids = input_ids[:,:self.max_seq_length]
+    out_weights = tuple(
+        torch.empty_like(
+            self.model.layers[0].input_layernorm.weight,
+            dtype = torch.float32,
+            device = torch.device(x),
+        )
+        for x in range(DEVICE_COUNT)
+    )
+    input_ids = input_ids[:, : self.max_seq_length]
     hidden_states = self.model.embed_tokens(input_ids)
     hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
     # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
     # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
-    hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
+    hidden_states *= torch.tensor(
+        math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
+    )
 
     bsz, q_len, hd = hidden_states.shape
     seq_len = past_key_values[0][0].shape[-2]
@@ -168,7 +191,6 @@ def GemmaModel_fast_forward_inference(
             hidden_states,
             seq_len,
         )
-    pass
 
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
@@ -178,7 +200,9 @@ def GemmaModel_fast_forward_inference(
         )
 
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weights[device_index])
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            decoder_layer.input_layernorm, hidden_states, out_weights[device_index]
+        )
         hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
             decoder_layer.self_attn,
             hidden_states = hidden_states,
@@ -190,13 +214,18 @@ def GemmaModel_fast_forward_inference(
         hidden_states += residual
 
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weights[device_index])
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            decoder_layer.post_attention_layernorm,
+            hidden_states,
+            out_weights[device_index],
+        )
         hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
         hidden_states += residual
 
         next_decoder_cache.append(present_key_value)
-    pass
-    hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weights[device_index])
+    hidden_states = fast_rms_layernorm_inference_gemma(
+        self.model.norm, hidden_states, out_weights[device_index]
+    )
 
     return BaseModelOutputWithPast(
         last_hidden_state = hidden_states,
@@ -204,7 +233,6 @@ def GemmaModel_fast_forward_inference(
         hidden_states = [],
         attentions = [],
     )
-pass
 
 
 # Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
@@ -213,35 +241,51 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
     # Fixes https://github.com/huggingface/transformers/pull/28837
     # https://github.com/microsoft/DeepSpeed/issues/4932
     # The precision of RoPE buffers is not correct, so we cast to int64.
-    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
-        config = None, # [TODO] Hack to pass in config - need to remove later
+    def __init__(
+        self,
+        dim = None,
+        max_position_embeddings = 2048,
+        base = 10000,
+        device = None,
+        config = None,  # [TODO] Hack to pass in config - need to remove later
     ):
         super().__init__()
         if config is not None:
             # [TODO] Hack to pass in config - need to remove later
             base = config.rope_theta
-            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            partial_rotary_factor = (
+                config.partial_rotary_factor
+                if hasattr(config, "partial_rotary_factor")
+                else 1.0
+            )
             dim = getattr(config, "head_dim", None)
-            if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
+            if dim is None:
+                dim = int((config.hidden_size // config.num_attention_heads))
             device = "cuda"
             max_position_embeddings = config.max_position_embeddings
-        pass
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
-        self.multi_gpu_cos_cached = [None]*DEVICE_COUNT
-        self.multi_gpu_sin_cached = [None]*DEVICE_COUNT
+        self.multi_gpu_cos_cached = [None] * DEVICE_COUNT
+        self.multi_gpu_sin_cached = [None] * DEVICE_COUNT
 
         # Build here to make `torch.jit.trace` work.
         for device in range(DEVICE_COUNT):
-            self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype())
+            self._set_cos_sin_cache(
+                seq_len = self.current_rope_size,
+                device = torch.device(device),
+                dtype = torch.get_default_dtype(),
+            )
 
         # dummy so that patch_utils doesn't fail for now
-        self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
-        self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
-    pass
+        self.cos_cached = torch.empty(
+            1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
+        )
+        self.sin_cached = torch.empty(
+            1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
+        )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
@@ -253,23 +297,24 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
             torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
         )
         timescale = self.base**freq_exponents
-        positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
+        positions = torch.arange(
+            self.current_rope_size, device = "cpu", dtype = torch.int64
+        ).float()
         radians_new = positions[..., None] / timescale[None, None, :]
         radians_new = radians_new.squeeze(0)
 
         emb = torch.cat((radians_new, radians_new), dim = -1)
         # We must do RoPE in float32!
-        cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
-        sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
+        cos = emb.cos().to(device = device, non_blocking = True)  # , dtype = dtype)
+        sin = emb.sin().to(device = device, non_blocking = True)  # , dtype = dtype)
         self.multi_gpu_cos_cached[device.index] = cos
         self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
-    pass
 
-    def forward(self, x, position_ids=None, seq_len=None):
+    def forward(self, x, position_ids = None, seq_len = None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len is not None and seq_len > self.current_rope_size:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+            self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
 
         device_index = x.device.index
 
@@ -277,35 +322,48 @@ def forward(self, x, position_ids=None, seq_len=None):
             self.multi_gpu_cos_cached[device_index][:seq_len],
             self.multi_gpu_sin_cached[device_index][:seq_len],
         )
-    pass
 
     def get_cached(self, seq_len = None, device_index = None):
         if device_index is None:
             device_index = torch.cuda.current_device()
-        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index]
-    pass
+        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
+            device_index
+        ]
 
     def extend_rope_embedding(self, x, seq_len):
-        if seq_len <= self.current_rope_size: return
+        if seq_len <= self.current_rope_size:
+            return
         # Iteratively grow by increments of 8192
         self.current_rope_size = math.ceil(seq_len / 8192) * 8192
         for device in range(DEVICE_COUNT):
-            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype)
-    pass
-pass
+            self._set_cos_sin_cache(
+                self.current_rope_size, device = torch.device(device), dtype = x.dtype
+            )
 
 
 class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
     """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
     # Fixes https://github.com/huggingface/transformers/pull/28837
     # https://github.com/microsoft/DeepSpeed/issues/4932
     # The precision of RoPE buffers is not correct, so we cast to int64.
-    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
-        config = None, # [TODO] Hack to pass in config - need to remove later
+    def __init__(
+        self,
+        dim = None,
+        max_position_embeddings = 2048,
+        base = 10000,
+        device = None,
+        scaling_factor = 1.0,
+        config = None,  # [TODO] Hack to pass in config - need to remove later
     ):
         self.scaling_factor = scaling_factor
-        super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
-    pass
+        super().__init__(
+            dim = dim,
+            max_position_embeddings = max_position_embeddings,
+            base = base,
+            device = device,
+            config = config,
+        )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
@@ -317,42 +375,42 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
             torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
         )
         timescale = self.base**freq_exponents
-        positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
-        positions = positions /  self.scaling_factor
+        positions = torch.arange(
+            self.current_rope_size, device = "cpu", dtype = torch.int64
+        ).float()
+        positions = positions / self.scaling_factor
         radians_new = positions[..., None] / timescale[None, None, :]
         radians_new = radians_new.squeeze(0)
 
         emb = torch.cat((radians_new, radians_new), dim = -1)
         # We must do RoPE in float32!
-        cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
-        sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
+        cos = emb.cos().to(device = device, non_blocking = True)  # , dtype = dtype)
+        sin = emb.sin().to(device = device, non_blocking = True)  # , dtype = dtype)
         self.multi_gpu_cos_cached[device.index] = cos
         self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
-    pass
-pass
 
 
 class FastGemmaModel(FastLlamaModel):
-
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "gemma",
-            rope_module        = GemmaFixedRotaryEmbedding,
+            model_name = "gemma",
+            rope_module = GemmaFixedRotaryEmbedding,
             scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
-            attention_module   = GemmaAttention,
+            attention_module = GemmaAttention,
         )
         if init_name is not None:
             exec(function, globals())
-            GemmaAttention.__init__  = eval(init_name)
-        pass
-        GemmaAttention      .forward = LlamaAttention_fast_forward
-        GemmaSdpaAttention  .forward = LlamaAttention_fast_forward
+            GemmaAttention.__init__ = eval(init_name)
+        GemmaAttention.forward = LlamaAttention_fast_forward
+        GemmaSdpaAttention.forward = LlamaAttention_fast_forward
         GemmaFlashAttention2.forward = LlamaAttention_fast_forward
-        GemmaDecoderLayer   .forward = GemmaDecoderLayer_fast_forward
-        GemmaModel          .forward = LlamaModel_fast_forward
-        GemmaForCausalLM    .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
+        GemmaDecoderLayer.forward = GemmaDecoderLayer_fast_forward
+        GemmaModel.forward = LlamaModel_fast_forward
+        GemmaForCausalLM.forward = CausalLM_fast_forward(
+            GemmaModel_fast_forward_inference
+        )
         PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(GemmaForCausalLM)
 
@@ -362,15 +420,18 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.gemma.modeling_gemma
-        transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
-        return
-    pass
 
+        transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = (
+            GemmaFixedRotaryEmbedding
+        )
+        return
 
     @staticmethod
     def post_patch(model, tokenizer):
         # Gemma does not downcast RoPE
-        model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = False)
+        model, tokenizer = patch_model_and_tokenizer(
+            model, tokenizer, downcast_rope = False
+        )
 
         # Add 1 to weight
         # return output * (1 + self.weight)
@@ -384,7 +445,6 @@ def post_patch(model, tokenizer):
                 param.requires_grad_(True)
             else:
                 param.requires_grad_(False)
-        pass
 
         # Patch RMS Layernorm
         for name, module in model.named_modules():
@@ -395,14 +455,14 @@ def post_patch(model, tokenizer):
                 # Leave + 1 to Triton kernel itself
                 # module.weight += 1.0 # return output * (1 + self.weight)
                 if not hasattr(module, "variance_epsilon"):
-                    module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
-        pass
+                    module.variance_epsilon = (
+                        module.eps
+                    )  # Gemma doesn't use variance_epsilon
 
         # Clear deleted GPU items
         import gc
+
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()
         return model, tokenizer
-    pass
-pass
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index 43eefbbbf..de70f7873 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -21,6 +21,7 @@
     GemmaFixedLinearScalingRotaryEmbedding,
     fast_geglu_inference,
 )
+
 try:
     from transformers.models.gemma2.modeling_gemma2 import (
         Gemma2Attention,
@@ -33,20 +34,20 @@
     )
 except:
     from packaging.version import Version
+
     transformers_version = Version(transformers_version)
     if not transformers_version >= Version("4.42"):
         raise ImportError(
-            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
-            f"The minimum required version is 4.42.3.\n"\
-            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
-            f"to obtain the latest transformers build, then restart this session."\
+            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"
+            f"The minimum required version is 4.42.3.\n"
+            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
+            f"to obtain the latest transformers build, then restart this session."
         )
-    pass
-pass
 
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
+
 # For Pytorch 2.1.1
 try:
     from transformers.models.gemma2.modeling_gemma2 import (
@@ -54,27 +55,27 @@
         Gemma2FlashAttention2,
     )
 except:
-    Gemma2SdpaAttention   = Gemma2Attention
+    Gemma2SdpaAttention = Gemma2Attention
     Gemma2FlashAttention2 = Gemma2Attention
-pass
 
 if HAS_FLASH_ATTENTION_SOFTCAPPING:
     from flash_attn import flash_attn_func
 
+
 # Logit softcapping
 def Gemma2Attention_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    bool = False,
-    use_cache:            bool = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
-    *args, **kwargs,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -84,18 +85,17 @@ def Gemma2Attention_fast_forward(
         del self.temp_KV
         del self.RH_Q
         del self.attention
-    pass
 
     bsz, q_len, _ = hidden_states.size()
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
+    head_dim = self.head_dim
+    assert n_kv_heads * n_groups == n_heads
 
     Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
+    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
     K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
@@ -111,12 +111,10 @@ def Gemma2Attention_fast_forward(
     else:
         cos, sin = self.rotary_emb.get_cached(kv_seq_len, device_index)
         Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
-    pass
 
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
         V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
     past_key_value = (K, V) if use_cache else None
 
     # Only enable if the attention_mask is True
@@ -127,114 +125,139 @@ def Gemma2Attention_fast_forward(
             sw = getattr(self.config, "sliding_window", None)
             sw = kv_seq_len if (sw is None or sw == "null") else sw
             window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
-        pass
 
         # FA uses 1 / sqrt for softmax_scale!
         if not hasattr(self, "_flash_attention_softmax_scale"):
-            self._flash_attention_softmax_scale = 1.0 / (self.config.query_pre_attn_scalar**0.5)
-        pass
+            self._flash_attention_softmax_scale = 1.0 / (
+                self.config.query_pre_attn_scalar**0.5
+            )
 
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
         A = flash_attn_func(
-            Q, K, V,
+            Q,
+            K,
+            V,
             causal = True,
             softcap = self.config.attn_logit_softcapping,
             softmax_scale = self._flash_attention_softmax_scale,
             window_size = window,
         )
-        A = A.reshape(bsz, q_len, n_heads*head_dim)
+        A = A.reshape(bsz, q_len, n_heads * head_dim)
     else:
-        fx = slow_inference_attention_softcapping \
-            if "_flag_for_generation" in kwargs else \
-            slow_attention_softcapping
+        fx = (
+            slow_inference_attention_softcapping
+            if "_flag_for_generation" in kwargs
+            else slow_attention_softcapping
+        )
         A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len)
-    pass
     A = self.apply_o(self, A)
     return A, None, past_key_value
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
 def Gemma2DecoderLayer_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    Optional[bool] = False,
-    use_cache:            Optional[bool] = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
-    *args, **kwargs,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    *args,
+    **kwargs,
 ):
-    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
-        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
+    if use_cache and hasattr(
+        self, "_flag_for_generation"
+    ):  # past_key_value is not None:
+        out_weight = torch.empty(
+            self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
+        )
 
         # Self Attention
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            self.input_layernorm, hidden_states, out_weight
+        )
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
-            _flag_for_generation=self._flag_for_generation,
-        )
-        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
+            _flag_for_generation = self._flag_for_generation,
+        )
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            self.post_attention_layernorm, hidden_states, out_weight
+        )
         hidden_states += residual
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            self.pre_feedforward_layernorm, hidden_states, out_weight
+        )
         hidden_states = fast_geglu_inference(self.mlp, hidden_states)
-        hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            self.post_feedforward_layernorm, hidden_states, out_weight
+        )
         hidden_states += residual
     else:
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm(
+            self.input_layernorm, hidden_states, gemma = True
+        )
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
-        )
-        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
+        )
+        hidden_states = fast_rms_layernorm(
+            self.post_attention_layernorm, hidden_states, gemma = True
+        )
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm(
+            self.pre_feedforward_layernorm, hidden_states, gemma = True
+        )
         hidden_states = self.mlp(hidden_states)
-        hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True)
+        hidden_states = fast_rms_layernorm(
+            self.post_feedforward_layernorm, hidden_states, gemma = True
+        )
         hidden_states = residual + hidden_states
-    pass
 
     outputs = (hidden_states,)
-    if output_attentions: outputs += (self_attn_weights,)
-    if use_cache: outputs += (present_key_value,)
+    if output_attentions:
+        outputs += (self_attn_weights,)
+    if use_cache:
+        outputs += (present_key_value,)
     return outputs
-pass
 
 
 from math import sqrt as math_sqrt
-KV_CACHE_INCREMENT = 256 # KV Cache update size
+
+KV_CACHE_INCREMENT = 256  # KV Cache update size
 torch_nn_functional_softmax = torch.nn.functional.softmax
 torch_matmul = torch.matmul
-torch_tanh   = torch.tanh
+torch_tanh = torch.tanh
+
 
 def Gemma2Attention_fast_forward_inference(
     self,
-    hidden_states:  torch.Tensor,
+    hidden_states: torch.Tensor,
     past_key_value: Optional[Tuple[torch.Tensor]],
     position_ids,
     do_prefill = False,
@@ -246,14 +269,14 @@ def Gemma2Attention_fast_forward_inference(
     K1, V1 = past_key_value
     dtype = Xn.dtype
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
+    head_dim = self.head_dim
     # assert(n_kv_heads * n_groups == n_heads)
 
     hidden_size = self.config.hidden_size
-    attention_size = n_heads*head_dim
+    attention_size = n_heads * head_dim
     seq_len = K1.shape[-2]
     kv_seq_len = seq_len + 1
     device = hidden_states.device
@@ -261,17 +284,27 @@ def Gemma2Attention_fast_forward_inference(
     # Prefill phase
     # if not hasattr(self, "paged_attention"):
     if do_prefill:
-        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention = torch.empty(
+            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
+            dtype = dtype,
+            device = device,
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
         self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
         self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
-        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
-        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
+        self.temp_QA = torch.empty(
+            (2, bsz, 1, attention_size), dtype = dtype, device = device
+        )
+        self.temp_KV = torch.empty(
+            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
+        )
         self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
         # Only for Gemma2
-        self.temp_O  = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
-        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
+        self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
+        self.attention = torch.empty(
+            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
+        )
 
         # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
         # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
@@ -280,19 +313,28 @@ def Gemma2Attention_fast_forward_inference(
         self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)
         # self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
         self.half_head_dim = head_dim // 2
-        self.           t =       self.config.attn_logit_softcapping
+        self.t = self.config.attn_logit_softcapping
         self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
     elif kv_seq_len >= self.paged_attention.shape[0]:
-        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
-    pass
+        self.paged_attention.resize_(
+            (
+                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
+                2,
+                bsz,
+                n_kv_heads,
+                head_dim,
+            )
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
+        self.attention.resize_(
+            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
+        )
 
     Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
     Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
     Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
-    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
+    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
     Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
     Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
 
@@ -304,16 +346,18 @@ def Gemma2Attention_fast_forward_inference(
     h = self.half_head_dim
 
     RH_Q = self.RH_Q
-    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
-    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
-    torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
+    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
+    torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
     Qn *= cos
     Qn.addcmul_(RH_Q, sin)
 
-    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
-    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
-    torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+    RH_K = RH_Q[
+        :, :n_kv_heads, :, :
+    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+    RH_K[:, :, :, :h] = Kn[:, :, :, h:]
+    RH_K[:, :, :, h:] = Kn[:, :, :, :h]
+    torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
 
@@ -330,34 +374,40 @@ def Gemma2Attention_fast_forward_inference(
     if use_sliding_window and kv_seq_len > sliding_window:
         # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
         slicing_tokens = 1 - sliding_window
-        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
-        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
+        Knn = Kn[:, :, slicing_tokens:, :]  # .contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]  # .contiguous()
     else:
         Knn, Vnn = Kn, Vn
-    pass
 
     # Grouped query attention
     _, _, cached_len, _ = Knn.shape
     if n_groups != 1:
-        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
+        Vnn = Vnn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
         Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
         Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-    pass
     # else:
     #     Knn, Vnn = Knn, Vnn
     # pass
 
     # Attention
     # if bsz == 1:
-    Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+    Qn *= (
+        self.scalar
+    )  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
     # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
-    A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+    A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
     # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
 
-    A *= self.reciprocal_t; torch_tanh(A, out = A); A *= self.t;  # Logit softcapping
+    A *= self.reciprocal_t
+    torch_tanh(A, out = A)
+    A *= self.t  # Logit softcapping
 
-    A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+    A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)  # .to(A.dtype)
     A = torch_matmul(A, Vnn, out = Qn)
     # else:
     #     A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
@@ -366,7 +416,6 @@ def Gemma2Attention_fast_forward_inference(
     A = A.reshape(bsz, 1, attention_size)
     A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
     return A, (Kn, Vn)
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
@@ -378,20 +427,29 @@ def Gemma2Model_fast_forward_inference(
     position_ids,
     attention_mask = None,
 ):
-    out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT))
-    input_ids = input_ids[:,:self.max_seq_length]
+    out_weights = tuple(
+        torch.empty_like(
+            self.model.layers[0].input_layernorm.weight,
+            dtype = torch.float32,
+            device = torch.device(x),
+        )
+        for x in range(DEVICE_COUNT)
+    )
+    input_ids = input_ids[:, : self.max_seq_length]
     hidden_states = self.model.embed_tokens(input_ids)
     hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
     # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
     # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
-    hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
+    hidden_states *= torch.tensor(
+        math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
+    )
 
     bsz, q_len, hd = hidden_states.shape
     seq_len = past_key_values[0][0].shape[-2]
     if bsz != 1:
         if HAS_FLASH_ATTENTION_SOFTCAPPING:
             SWA = True
-            GA  = False
+            GA = False
         else:
             SWA = _prepare_4d_causal_attention_mask_for_sdpa(
                 attention_mask,
@@ -406,14 +464,11 @@ def Gemma2Model_fast_forward_inference(
                 hidden_states,
                 seq_len,
             )
-        pass
     else:
         SWA = attention_mask
-        GA  = attention_mask
-    pass
+        GA = attention_mask
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
-
         # For pipeline parallelism, we need to move all tensors to the same device
         # note that this movement is once per GPU in PP
         device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
@@ -424,7 +479,9 @@ def Gemma2Model_fast_forward_inference(
         use_sliding_window = idx % 2 == 0
 
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weights[device_index])
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            decoder_layer.input_layernorm, hidden_states, out_weights[device_index]
+        )
         hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
             decoder_layer.self_attn,
             hidden_states = hidden_states,
@@ -434,18 +491,31 @@ def Gemma2Model_fast_forward_inference(
             do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
             use_sliding_window = use_sliding_window,
         )
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weights[device_index])
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            decoder_layer.post_attention_layernorm,
+            hidden_states,
+            out_weights[device_index],
+        )
         hidden_states += residual
 
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weights[device_index])
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            decoder_layer.pre_feedforward_layernorm,
+            hidden_states,
+            out_weights[device_index],
+        )
         hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
-        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weights[device_index])
+        hidden_states = fast_rms_layernorm_inference_gemma(
+            decoder_layer.post_feedforward_layernorm,
+            hidden_states,
+            out_weights[device_index],
+        )
         hidden_states += residual
 
         next_decoder_cache.append(present_key_value)
-    pass
-    hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weights[device_index])
+    hidden_states = fast_rms_layernorm_inference_gemma(
+        self.model.norm, hidden_states, out_weights[device_index]
+    )
 
     return BaseModelOutputWithPast(
         last_hidden_state = hidden_states,
@@ -453,30 +523,29 @@ def Gemma2Model_fast_forward_inference(
         hidden_states = [],
         attentions = [],
     )
-pass
 
 
 class FastGemma2Model(FastLlamaModel):
-
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "gemma2",
-            rope_module        = GemmaFixedRotaryEmbedding,
+            model_name = "gemma2",
+            rope_module = GemmaFixedRotaryEmbedding,
             scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
-            attention_module   = Gemma2Attention,
+            attention_module = Gemma2Attention,
         )
         if init_name is not None:
             exec(function, globals())
-            Gemma2Attention.__init__  = eval(init_name)
-        pass
-        Gemma2Attention      .forward = Gemma2Attention_fast_forward
-        Gemma2SdpaAttention  .forward = Gemma2Attention_fast_forward
+            Gemma2Attention.__init__ = eval(init_name)
+        Gemma2Attention.forward = Gemma2Attention_fast_forward
+        Gemma2SdpaAttention.forward = Gemma2Attention_fast_forward
         Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
-        Gemma2DecoderLayer   .forward = Gemma2DecoderLayer_fast_forward
-        Gemma2Model          .forward = LlamaModel_fast_forward
-        Gemma2ForCausalLM    .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
-        PeftModelForCausalLM .forward = PeftModel_fast_forward
+        Gemma2DecoderLayer.forward = Gemma2DecoderLayer_fast_forward
+        Gemma2Model.forward = LlamaModel_fast_forward
+        Gemma2ForCausalLM.forward = CausalLM_fast_forward(
+            Gemma2Model_fast_forward_inference
+        )
+        PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
 
         # Solves https://github.com/unslothai/unsloth/issues/168
@@ -485,15 +554,18 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.gemma2.modeling_gemma2
-        transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
-        return
-    pass
 
+        transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = (
+            GemmaFixedRotaryEmbedding
+        )
+        return
 
     @staticmethod
     def post_patch(model, tokenizer):
         # Gemma does not downcast RoPE
-        model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = False)
+        model, tokenizer = patch_model_and_tokenizer(
+            model, tokenizer, downcast_rope = False
+        )
 
         # Add 1 to weight
         # return output * (1 + self.weight)
@@ -507,7 +579,6 @@ def post_patch(model, tokenizer):
                 param.requires_grad_(True)
             else:
                 param.requires_grad_(False)
-        pass
 
         # Patch RMS Layernorm
         for name, module in model.named_modules():
@@ -518,14 +589,14 @@ def post_patch(model, tokenizer):
                 # Leave + 1 to Triton kernel itself
                 # module.weight += 1.0 # return output * (1 + self.weight)
                 if not hasattr(module, "variance_epsilon"):
-                    module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
-        pass
+                    module.variance_epsilon = (
+                        module.eps
+                    )  # Gemma doesn't use variance_epsilon
 
         # Clear deleted GPU items
         import gc
+
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()
         return model, tokenizer
-    pass
-pass
diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py
index e07575900..0a816e399 100644
--- a/unsloth/models/granite.py
+++ b/unsloth/models/granite.py
@@ -24,6 +24,7 @@
 from .mistral import *
 from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
 from peft.tuners.lora import Linear4bit as Peft_Linear4bit
+
 try:
     from transformers.models.granite.modeling_granite import (
         GraniteAttention,
@@ -37,13 +38,11 @@
     transformers_version = Version(transformers_version)
     if not transformers_version >= Version("4.45.0"):
         raise ImportError(
-            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
-            f"The minimum required version is 4.42.3.\n"\
-            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
-            f"to obtain the latest transformers build, then restart this session."\
+            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"
+            f"The minimum required version is 4.42.3.\n"
+            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
+            f"to obtain the latest transformers build, then restart this session."
         )
-    pass
-pass
 
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
@@ -56,24 +55,24 @@
         GraniteFlashAttention2,
     )
 except:
-    GraniteSdpaAttention   = GraniteAttention
+    GraniteSdpaAttention = GraniteAttention
     GraniteFlashAttention2 = GraniteAttention
-pass
+
 
 def GraniteAttention_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    bool = False,
-    use_cache:            bool = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -83,19 +82,18 @@ def GraniteAttention_fast_forward(
         del self.temp_KV
         del self.RH_Q
         del self.attention
-    pass
 
     bsz, q_len, _ = hidden_states.size()
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
-    dropout_p  = self.config.attention_dropout if self.training else 0
-    assert(n_kv_heads * n_groups == n_heads)
+    head_dim = self.head_dim
+    dropout_p = self.config.attention_dropout if self.training else 0
+    assert n_kv_heads * n_groups == n_heads
 
     Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
+    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
     K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
@@ -113,11 +111,10 @@ def GraniteAttention_fast_forward(
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
         V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
     past_key_value = (K, V) if use_cache else None
 
     # Attention module
-    if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
+    if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
         # Xformers memory efficient attention
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
@@ -126,8 +123,8 @@ def GraniteAttention_fast_forward(
         Q_M = bsz * q_len
 
         # Group query attention
-        K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
-        V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+        K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+        V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
         K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         if hidden_states.requires_grad:
@@ -136,9 +133,10 @@ def GraniteAttention_fast_forward(
         else:
             # Xformers does support the forward pass though
             Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-        pass
 
-        A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling, p=dropout_p)
+        A = xformers_attention(
+            Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p
+        )
         A = A.view(bsz, q_len, n_heads, head_dim)
 
     elif HAS_FLASH_ATTENTION and attention_mask is None:
@@ -146,7 +144,15 @@ def GraniteAttention_fast_forward(
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
         window = (kv_seq_len, kv_seq_len)
-        A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling, dropout_p=dropout_p)
+        A = flash_attn_func(
+            Q,
+            K,
+            V,
+            causal = True,
+            window_size = window,
+            softmax_scale = self.scaling,
+            dropout_p = dropout_p,
+        )
     else:
         # Grouped query attention
         # if n_groups != 1:
@@ -160,70 +166,84 @@ def GraniteAttention_fast_forward(
         Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
         # Needs (batch_size, n_heads, seq_len, head_dim)
         # is_casual and attention_mask must not be both set!
-        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False, dropout_p=dropout_p)
+        A = scaled_dot_product_attention(
+            Q,
+            K,
+            V,
+            attn_mask = attention_mask,
+            scale = self.scaling,
+            is_causal = False,
+            dropout_p = dropout_p,
+        )
         # Go back to (batch_size, seq_len, n_heads, head_dim)
         A = A.transpose(1, 2).contiguous()
-    pass
 
-    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
+    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
     attn_output = self.apply_o(self, attn_output)
     attn_weights = None
     return attn_output, attn_weights, past_key_value
-pass
 
 
 def GraniteDecoderLayer_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    Optional[bool] = False,
-    use_cache:            Optional[bool] = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
-    position_embeddings:  Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    *args,
+    **kwargs,
 ):
-    residual_multiplier = \
-        self.residual_multiplier \
-        if hasattr(self, "residual_multiplier") else \
-        self.config.residual_multiplier
+    residual_multiplier = (
+        self.residual_multiplier
+        if hasattr(self, "residual_multiplier")
+        else self.config.residual_multiplier
+    )
 
-    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
+    if use_cache and hasattr(
+        self, "_flag_for_generation"
+    ):  # past_key_value is not None:
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.input_layernorm, hidden_states
+        )
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
-            _flag_for_generation=self._flag_for_generation,
+            _flag_for_generation = self._flag_for_generation,
         )
         hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.post_attention_layernorm, hidden_states
+        )
         hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
         hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
     else:
         residual = hidden_states
         hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
         )
         hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
@@ -233,47 +253,50 @@ def GraniteDecoderLayer_fast_forward(
         hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
-    pass
 
     outputs = (hidden_states,)
-    if output_attentions: outputs += (self_attn_weights,)
-    if use_cache: outputs += (present_key_value,)
+    if output_attentions:
+        outputs += (self_attn_weights,)
+    if use_cache:
+        outputs += (present_key_value,)
     return outputs
-pass
 
 
 from math import sqrt as math_sqrt
-KV_CACHE_INCREMENT = 256 # KV Cache update size
+
+KV_CACHE_INCREMENT = 256  # KV Cache update size
 torch_nn_functional_softmax = torch.nn.functional.softmax
 torch_matmul = torch.matmul
-torch_tanh   = torch.tanh
+torch_tanh = torch.tanh
+
 
 def GraniteAttention_fast_forward_inference(
     self,
-    hidden_states:  torch.Tensor,
+    hidden_states: torch.Tensor,
     past_key_value: Optional[Tuple[torch.Tensor]],
     position_ids,
     do_prefill = False,
     attention_mask = None,
     use_sliding_window = False,
-    position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 ):
-
-    assert position_embeddings is not None, f"Granite model requires position embeddings to be specified"
+    assert (
+        position_embeddings is not None
+    ), f"Granite model requires position embeddings to be specified"
 
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
     dtype = Xn.dtype
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
+    head_dim = self.head_dim
     # assert(n_kv_heads * n_groups == n_heads)
 
     hidden_size = self.config.hidden_size
-    attention_size = n_heads*head_dim
+    attention_size = n_heads * head_dim
     seq_len = K1.shape[-2]
     kv_seq_len = seq_len + 1
     device = hidden_states.device
@@ -281,31 +304,49 @@ def GraniteAttention_fast_forward_inference(
     # Prefill phase
     # if not hasattr(self, "paged_attention"):
     if do_prefill:
-        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention = torch.empty(
+            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
+            dtype = dtype,
+            device = device,
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
         self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
         self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
-        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
-        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
+        self.temp_QA = torch.empty(
+            (2, bsz, 1, attention_size), dtype = dtype, device = device
+        )
+        self.temp_KV = torch.empty(
+            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
+        )
         self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
         # Only for Gemma2
-        self.temp_O  = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
-        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
-
+        self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
+        self.attention = torch.empty(
+            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
+        )
 
         self.half_head_dim = head_dim // 2
     elif kv_seq_len >= self.paged_attention.shape[0]:
-        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
-    pass
+        self.paged_attention.resize_(
+            (
+                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
+                2,
+                bsz,
+                n_kv_heads,
+                head_dim,
+            )
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
+        self.attention.resize_(
+            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
+        )
 
     Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
     Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
     Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
-    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
+    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
     Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
     Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
 
@@ -316,16 +357,18 @@ def GraniteAttention_fast_forward_inference(
     h = self.half_head_dim
 
     RH_Q = self.RH_Q
-    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
-    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
-    torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
+    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
+    torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
     Qn *= cos
     Qn.addcmul_(RH_Q, sin)
 
-    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
-    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
-    torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+    RH_K = RH_Q[
+        :, :n_kv_heads, :, :
+    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+    RH_K[:, :, :, :h] = Kn[:, :, :, h:]
+    RH_K[:, :, :, h:] = Kn[:, :, :, :h]
+    torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
 
@@ -340,21 +383,24 @@ def GraniteAttention_fast_forward_inference(
     # Grouped query attention
     _, _, cached_len, _ = Kn.shape
     if n_groups != 1:
-        Kn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Kn = Kn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
+        Vn = Vn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
         Kn = Kn.reshape(bsz, n_heads, cached_len, head_dim)
         Vn = Vn.reshape(bsz, n_heads, cached_len, head_dim)
-    pass
     # else:
     #     Kn, Vn = Kn, Vn
     # pass
 
     Qn *= self.scaling
-    A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+    A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
 
     # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
 
-    A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+    A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)  # .to(A.dtype)
     A = torch_matmul(A, Vn, out = Qn)
     # else:
     #     A = scaled_dot_product_attention(Qn, Kn, Vn, attn_mask = attention_mask, is_causal = False)
@@ -363,7 +409,6 @@ def GraniteAttention_fast_forward_inference(
     A = A.reshape(bsz, 1, attention_size)
     A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
     return A, (Kn, Vn)
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
@@ -375,14 +420,15 @@ def GraniteModel_fast_forward_inference(
     position_ids,
     attention_mask = None,
 ):
-    input_ids = input_ids[:,:self.max_seq_length]
+    input_ids = input_ids[:, : self.max_seq_length]
     hidden_states = self.model.embed_tokens(input_ids)
     hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
     hidden_states *= self.model.embedding_multiplier
-    residual_multiplier = \
-        self.residual_multiplier \
-        if hasattr(self, "residual_multiplier") else \
-        self.config.residual_multiplier
+    residual_multiplier = (
+        self.residual_multiplier
+        if hasattr(self, "residual_multiplier")
+        else self.config.residual_multiplier
+    )
 
     bsz, q_len, hd = hidden_states.shape
     seq_len = past_key_values[0][0].shape[-2]
@@ -395,9 +441,10 @@ def GraniteModel_fast_forward_inference(
         )
     else:
         attention_mask = None
-    pass
 
-    position_embeddings = self.model.rotary_emb.get_cached(self.max_seq_length, hidden_states.device.index)
+    position_embeddings = self.model.rotary_emb.get_cached(
+        self.max_seq_length, hidden_states.device.index
+    )
 
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
@@ -407,7 +454,9 @@ def GraniteModel_fast_forward_inference(
         )
 
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            decoder_layer.input_layernorm, hidden_states
+        )
         hidden_states, present_key_value = GraniteAttention_fast_forward_inference(
             decoder_layer.self_attn,
             hidden_states = hidden_states,
@@ -421,12 +470,13 @@ def GraniteModel_fast_forward_inference(
         hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
 
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            decoder_layer.post_attention_layernorm, hidden_states
+        )
         hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
         hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
 
         next_decoder_cache.append(present_key_value)
-    pass
     hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
 
     return BaseModelOutputWithPast(
@@ -435,12 +485,13 @@ def GraniteModel_fast_forward_inference(
         hidden_states = [],
         attentions = [],
     )
-pass
+
 
 class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
     def __init__(self, config):
         super().__init__(config = config)
 
+
 def patched_init(original_init):
     def new_init(self, *args, **kwargs):
         # we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here
@@ -451,64 +502,70 @@ def new_init(self, *args, **kwargs):
         if config is not None:
             self.config = config
         original_init(self, *args, **kwargs)
+
     return new_init
 
-class FastGraniteModel(FastLlamaModel):
 
+class FastGraniteModel(FastLlamaModel):
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "granite",
-            rope_module        = GraniteRotaryEmbedding,
+            model_name = "granite",
+            rope_module = GraniteRotaryEmbedding,
             scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
-            attention_module   = GraniteAttention,
+            attention_module = GraniteAttention,
         )
         if init_name is not None:
             exec(function, globals())
-            GraniteAttention.__init__  = eval(init_name)
-        pass
-        GraniteAttention      .forward  = GraniteAttention_fast_forward
-        GraniteSdpaAttention  .forward  = GraniteAttention_fast_forward
-        GraniteFlashAttention2.forward  = GraniteAttention_fast_forward
-        GraniteDecoderLayer   .forward  = GraniteDecoderLayer_fast_forward
-        GraniteModel          .forward  = LlamaModel_fast_forward
-        GraniteForCausalLM    .forward  = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
-        GraniteForCausalLM    .__init__ = patched_init(GraniteForCausalLM.__init__)
-        PeftModelForCausalLM .forward = PeftModel_fast_forward
+            GraniteAttention.__init__ = eval(init_name)
+        GraniteAttention.forward = GraniteAttention_fast_forward
+        GraniteSdpaAttention.forward = GraniteAttention_fast_forward
+        GraniteFlashAttention2.forward = GraniteAttention_fast_forward
+        GraniteDecoderLayer.forward = GraniteDecoderLayer_fast_forward
+        GraniteModel.forward = LlamaModel_fast_forward
+        GraniteForCausalLM.forward = CausalLM_fast_forward(
+            GraniteModel_fast_forward_inference
+        )
+        GraniteForCausalLM.__init__ = patched_init(GraniteForCausalLM.__init__)
+        PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(GraniteForCausalLM)
 
         import transformers.models.granite.modeling_granite
-        transformers.models.granite.modeling_granite.GraniteRotaryEmbedding = GraniteRotaryEmbedding
 
-        return
-    pass
+        transformers.models.granite.modeling_granite.GraniteRotaryEmbedding = (
+            GraniteRotaryEmbedding
+        )
 
+        return
 
     @staticmethod
     def post_patch(model, tokenizer):
-
         # Torch.compile fails on embedding matrix??
         # Workaround randomnly fixes it for torch versions < 2.2
-        model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
-        model.config.update({"unsloth_version" : __version__})
+        model.model.embed_tokens = torch.nn.Embedding.from_pretrained(
+            model.model.embed_tokens.weight
+        )
+        model.config.update({"unsloth_version": __version__})
 
         # We also do this for the lm_head
         lm_head = torch.nn.Linear(1, 1, bias = None)
         del lm_head.weight
         lm_head.weight = model.lm_head.weight
-        lm_head.in_features  = lm_head.weight.shape[1]
+        lm_head.in_features = lm_head.weight.shape[1]
         lm_head.out_features = lm_head.weight.shape[0]
         model.lm_head = lm_head
 
         # Granite has tied weights! This means lm_head == embed_tokens
-        if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
+        if (
+            model.model.embed_tokens.weight.data_ptr()
+            != model.lm_head.weight.data_ptr()
+        ):
             lm_head = torch.nn.Linear(1, 1, bias = None)
             del lm_head.weight
             lm_head.weight = model.model.embed_tokens.weight
-            lm_head.in_features  = lm_head.weight.shape[1]
+            lm_head.in_features = lm_head.weight.shape[1]
             lm_head.out_features = lm_head.weight.shape[0]
             model.lm_head = lm_head
-        pass
 
         # Also patch all dtypes - BnB seems to not allocate the correct type?
         # BnB default dtype seems to be float16!
@@ -521,35 +578,30 @@ def post_patch(model, tokenizer):
 
                 if type(quant_state) is list:
                     # BnB seems to have float16 as default!
-                    module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
+                    module.weight.quant_state[2] = (
+                        correct_dtype  # Cast to correct dtype
+                    )
                 else:
                     # https://github.com/TimDettmers/bitsandbytes/pull/763/files
                     quant_state.dtype = correct_dtype
-                pass
-            pass
             # Downcast RoPE embedding to correct data type
-            if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")):
-
-                if hasattr(module, "cos_cached") and \
-                    (module.cos_cached.dtype != correct_dtype):
-
+            if name.endswith("rotary_emb") or hasattr(module, "cos_cached"):
+                if hasattr(module, "cos_cached") and (
+                    module.cos_cached.dtype != correct_dtype
+                ):
                     module.cos_cached = module.cos_cached.to(correct_dtype)
                     module.sin_cached = module.sin_cached.to(correct_dtype)
 
-                elif hasattr(module, "short_cos_cached") and \
-                    (module.short_cos_cached.dtype != correct_dtype):
-
+                elif hasattr(module, "short_cos_cached") and (
+                    module.short_cos_cached.dtype != correct_dtype
+                ):
                     module.short_cos_cached = module.short_cos_cached.to(correct_dtype)
                     module.short_sin_cached = module.short_sin_cached.to(correct_dtype)
-                pass
-            pass
-        pass
 
         # Clear deleted GPU items
         import gc
+
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()
         return model, tokenizer
-    pass
-pass
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 6abd49222..aa3167ce0 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -25,7 +25,11 @@
 from torch.nn.functional import scaled_dot_product_attention
 from transformers import __version__ as transformers_version
 from unsloth_zoo.utils import Version, _get_dtype
-from unsloth_zoo.hf_utils import dtype_from_config, add_dtype_kwargs, fix_lora_auto_mapping
+from unsloth_zoo.hf_utils import (
+    dtype_from_config,
+    add_dtype_kwargs,
+    fix_lora_auto_mapping,
+)
 from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES
 from ..device_type import (
     is_hip,
@@ -54,6 +58,7 @@
 )
 from ..kernels import *
 from ..tokenizer_utils import *
+
 if HAS_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
 from .vision import FastBaseModel
@@ -73,11 +78,16 @@
         LlamaFlashAttention2,
     )
 except:
-    LlamaSdpaAttention   = LlamaAttention
+    LlamaSdpaAttention = LlamaAttention
     LlamaFlashAttention2 = LlamaAttention
-pass
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, BitsAndBytesConfig, AutoConfig
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    BitsAndBytesConfig,
+    AutoConfig,
+)
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
 from transformers import set_seed as transformers_set_seed
 from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
@@ -85,15 +95,18 @@
 from ..save import patch_saving_functions
 import re, os, inspect, math, sys
 import types
+
 try:
     from huggingface_hub.utils import get_token
 except:
     # Old HF Hub versions <= 0.0.25
     from huggingface_hub.utils._token import get_token
-pass
 from triton import __version__ as triton_version
+
 HAS_XFORMERS = xformers is not None
-BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
+BlockDiagonalCausalMask = (
+    xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
+)
 
 if DEVICE_TYPE == "xpu":
     clean_gpu_cache = torch.xpu.empty_cache
@@ -101,30 +114,35 @@
 else:
     clean_gpu_cache = torch.cuda.empty_cache
     get_current_device = torch.cuda.current_device
-pass
+
 
 def original_apply_qkv(self, X):
     Q = self.q_proj(X)
     K = self.k_proj(X)
     V = self.v_proj(X)
     return Q, K, V
-pass
 
 
 def original_apply_o(self, X):
     O = self.o_proj(X)
     return O
-pass
+
 
 from math import sqrt as math_sqrt
-KV_CACHE_INCREMENT = 512 # KV Cache update size
+
+KV_CACHE_INCREMENT = 512  # KV Cache update size
 torch_nn_functional_softmax = torch.nn.functional.softmax
 # SDPA has GQA internally
 SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__
 
 
 # Fix new HF's inference code
-def _fast_prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs,):
+def _fast_prepare_inputs_for_generation(
+    self,
+    input_ids,
+    attention_mask = None,
+    **kwargs,
+):
     past_key_values = kwargs.get("past_key_values", None)
     if past_key_values is not None:
         # Check for uninitialized DynamicCache
@@ -132,19 +150,25 @@ def _fast_prepare_inputs_for_generation(self, input_ids, attention_mask=None, **
             past_key_values = None
             kwargs["past_key_values"] = None
         # New since 4.56
-        elif hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0:
+        elif (
+            hasattr(past_key_values, "get_seq_length")
+            and past_key_values.get_seq_length() == 0
+        ):
             past_key_values = None
             kwargs["past_key_values"] = None
         else:
             bs, cache_length = input_ids.shape
-            input_ids = input_ids[:,[-1]]
+            input_ids = input_ids[:, [-1]]
 
             # Get to the base model
             base_model = self
-            if hasattr(base_model, 'base_model_prefix'):
+            if hasattr(base_model, "base_model_prefix"):
                 base_model = getattr(base_model, base_model.base_model_prefix)
 
-            if hasattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position"):
+            if hasattr(
+                base_model, "_prepare_4d_causal_attention_mask_with_cache_position"
+            ):
+
                 def needs_device_kw(fn) -> bool:
                     try:
                         sig = inspect.signature(inspect.unwrap(fn))
@@ -157,23 +181,31 @@ def needs_device_kw(fn) -> bool:
                     "sequence_length": 1,
                     "target_length": cache_length,
                     "dtype": self.dtype,
-                    "cache_position": torch.arange(cache_length, cache_length+1, device=input_ids.device),
+                    "cache_position": torch.arange(
+                        cache_length, cache_length + 1, device = input_ids.device
+                    ),
                     "batch_size": bs,
                     "config": self.config,
                     "past_key_values": past_key_values,
                 }
                 try:
-                    if needs_device_kw(base_model._prepare_4d_causal_attention_mask_with_cache_position):
+                    if needs_device_kw(
+                        base_model._prepare_4d_causal_attention_mask_with_cache_position
+                    ):
                         kwargs["device"] = input_ids.device
                 except:
-                    print(f"Unsloth: Could not inspect signature of {base_model._prepare_4d_causal_attention_mask_with_cache_position}")
+                    print(
+                        f"Unsloth: Could not inspect signature of {base_model._prepare_4d_causal_attention_mask_with_cache_position}"
+                    )
 
-                attention_mask = base_model._prepare_4d_causal_attention_mask_with_cache_position(
-                    attention_mask,
-                    **kwargs,
+                attention_mask = (
+                    base_model._prepare_4d_causal_attention_mask_with_cache_position(
+                        attention_mask,
+                        **kwargs,
+                    )
                 )
             else:
-                attention_mask = attention_mask[:,[-1]]
+                attention_mask = attention_mask[:, [-1]]
                 if transformers_version <= Version("4.52.4"):
                     logger.warning_once(
                         f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
@@ -184,67 +216,71 @@ def needs_device_kw(fn) -> bool:
 
     if "cache_position" in kwargs:
         kwargs["position_ids"] = kwargs["cache_position"]
-    return { "input_ids" : input_ids, "attention_mask": attention_mask, **kwargs, }
-pass
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        **kwargs,
+    }
 
 
 def fix_prepare_inputs_for_generation(module):
     # Fix prepare_inputs_for_generation
     if hasattr(module, "prepare_inputs_for_generation"):
         module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
-    pass
-pass
+
 
 torch_matmul = torch.matmul
+
+
 def LlamaAttention_fast_forward_inference(
     self,
-    hidden_states:  torch.Tensor,
+    hidden_states: torch.Tensor,
     past_key_value: Optional[Tuple[torch.Tensor]],
     position_ids,
     do_prefill = False,
     attention_mask = None,
 ):
     """
-        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
-        Fast inference using KV cache.
-        QK^T can be computed in 4 chunks
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
+    Fast inference using KV cache.
+    QK^T can be computed in 4 chunks
 
-        [Q, q] @ [K, k].T where q, k are the new tokens.
-        [QK^T, Qk^T]
-        [qK^T, qk^T]
+    [Q, q] @ [K, k].T where q, k are the new tokens.
+    [QK^T, Qk^T]
+    [qK^T, qk^T]
 
-        Since the attention mask wipes Qk^T, we just get
-        [QK^T,    0]
-        [qK^T, qk^T]
+    Since the attention mask wipes Qk^T, we just get
+    [QK^T,    0]
+    [qK^T, qk^T]
 
-        Since softmax is row-wise, we get
-        softmax([QK^T,    0])
-        softmax([qK^T, qk^T])
+    Since softmax is row-wise, we get
+    softmax([QK^T,    0])
+    softmax([qK^T, qk^T])
 
-        We then multiply by   [V]
-                              [v]
-        softmax([QK^T,    0]) [softmax(QK^T)V] *
-        softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
+    We then multiply by   [V]
+                          [v]
+    softmax([QK^T,    0]) [softmax(QK^T)V] *
+    softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
 
-        But notice * [softmax(QK^T)V] is just the last attention.
-        We just need to compute the last final row.
+    But notice * [softmax(QK^T)V] is just the last attention.
+    We just need to compute the last final row.
 
-        This means we can pass in a row of Q, but we need to
-        remember K and V, which are called the KV cache.
+    This means we can pass in a row of Q, but we need to
+    remember K and V, which are called the KV cache.
     """
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
     dtype = Xn.dtype
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
+    head_dim = self.head_dim
     # assert(n_kv_heads * n_groups == n_heads)
 
     hidden_size = self.config.hidden_size
-    attention_size = n_heads*head_dim
+    attention_size = n_heads * head_dim
     seq_len = K1.shape[-2]
     kv_seq_len = seq_len + 1
 
@@ -252,36 +288,54 @@ def LlamaAttention_fast_forward_inference(
     # if not hasattr(self, "paged_attention"):
     device = hidden_states.device
     if do_prefill:
-        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention = torch.empty(
+            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
+            dtype = dtype,
+            device = device,
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
         self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
         self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
-        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
-        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
+        self.temp_QA = torch.empty(
+            (2, bsz, 1, attention_size), dtype = dtype, device = device
+        )
+        self.temp_KV = torch.empty(
+            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
+        )
         self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
 
         # Mistral Nemo 12b has weird dimensions
         if attention_size != hidden_size:
             self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
         else:
-            self.temp_O = self.temp_QA[1][:,:,:hidden_size]
-        pass
+            self.temp_O = self.temp_QA[1][:, :, :hidden_size]
 
-        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
+        self.attention = torch.empty(
+            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
+        )
         self.scalar = 1.0 / math_sqrt(self.head_dim)
         self.half_head_dim = head_dim // 2
     elif kv_seq_len >= self.paged_attention.shape[0]:
-        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
-    pass
+        self.paged_attention.resize_(
+            (
+                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
+                2,
+                bsz,
+                n_kv_heads,
+                head_dim,
+            )
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
+        self.attention.resize_(
+            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
+        )
 
     Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
     Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
     Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
-    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
+    Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
     Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
     Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
 
@@ -297,16 +351,18 @@ def LlamaAttention_fast_forward_inference(
     h = self.half_head_dim
 
     RH_Q = self.RH_Q
-    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
-    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
-    RH_Q[:,:,:,:h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
+    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
+    RH_Q[:, :, :, :h].neg_()  # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
     Qn *= cos
     Qn.addcmul_(RH_Q, sin)
 
-    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
-    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
-    RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+    RH_K = RH_Q[
+        :, :n_kv_heads, :, :
+    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+    RH_K[:, :, :, :h] = Kn[:, :, :, h:]
+    RH_K[:, :, :, h:] = Kn[:, :, :, :h]
+    RH_K[:, :, :, :h].neg_()  # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
 
@@ -323,20 +379,22 @@ def LlamaAttention_fast_forward_inference(
     if sliding_window is not None and kv_seq_len > sliding_window:
         # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
         slicing_tokens = 1 - sliding_window
-        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
-        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
+        Knn = Kn[:, :, slicing_tokens:, :]  # .contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]  # .contiguous()
     else:
         Knn, Vnn = Kn, Vn
-    pass
 
     # Grouped query attention
     _, _, cached_len, _ = Knn.shape
     if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1:
-        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
+        Vnn = Vnn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
         Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
         Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-    pass
     # else:
     #     Knn, Vnn = Knn, Vnn
     # pass
@@ -350,27 +408,42 @@ def LlamaAttention_fast_forward_inference(
         is_causal = False
     # Attention
     if bsz == 1:
-        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
         # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
-        A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+        A = torch_matmul(
+            Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
+        )
         # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
-        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+        A[:] = torch_nn_functional_softmax(
+            A, dim = -1, dtype = torch.float32
+        )  # .to(A.dtype)
         A = torch_matmul(A, Vnn, out = Qn)
     else:
         if SDPA_HAS_GQA:
-            A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal, enable_gqa = True)
+            A = scaled_dot_product_attention(
+                Qn,
+                Knn,
+                Vnn,
+                attn_mask = attention_mask,
+                is_causal = is_causal,
+                enable_gqa = True,
+            )
         else:
-            A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal)
-    pass
+            A = scaled_dot_product_attention(
+                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
+            )
     A = A.transpose(1, 2)
     A = A.reshape(bsz, 1, attention_size)
     A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
     return A, (Kn, Vn)
-pass
 
 
 torch_nn_functional_silu = torch.nn.functional.silu
-def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None, gate_multiplier = None, down_multiplier = None):
+
+
+def fast_swiglu_inference(
+    self, X, temp_gate = None, temp_up = None, gate_multiplier = None, down_multiplier = None
+):
     # gate = self.gate_proj(X)
     # up   = self.up_proj(X)
     bsz, _, hd = X.shape
@@ -382,22 +455,24 @@ def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None, gate_multip
     if gate_multiplier is not None:
         gate *= gate_multiplier
 
-    up   = fast_linear_forward(self.  up_proj, X, out = temp_up)
+    up = fast_linear_forward(self.up_proj, X, out = temp_up)
 
     gate = torch_nn_functional_silu(gate, inplace = True)
     gate *= up
 
     # X = self.down_proj(gate)
-    down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
+    down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])
 
     if down_multiplier is not None:
         down *= down_multiplier
 
     return down
-pass
+
 
 torch_square = torch.square
-torch_mean   = torch.mean
+torch_mean = torch.mean
+
+
 def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None):
     old_dtype = X.dtype
     if XX is None:
@@ -406,16 +481,16 @@ def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None
     else:
         XX.copy_(X)
         torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance)
-    pass
     variance += self.variance_epsilon
     XX *= variance.rsqrt_()
 
-    if XX is None: X = XX.to(old_dtype)
-    else: X.copy_(XX)
+    if XX is None:
+        X = XX.to(old_dtype)
+    else:
+        X.copy_(XX)
 
     X *= self.weight
     return X
-pass
 
 
 def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
@@ -429,11 +504,9 @@ def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
     else:
         out_weight[:] = self.weight
         out_weight += 1.0
-    pass
 
     XX *= out_weight
     return XX.to(X.dtype)
-pass
 
 
 # Normal layernorm with mean removal
@@ -443,28 +516,29 @@ def fast_layernorm_compiled(layernorm, X):
     X = X.float()
     mean = X.mean(-1, keepdim = True)
     Xbar = X - mean
-    X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \
-        layernorm.variance_epsilon) * \
-        layernorm.weight.float()
+    X = (
+        Xbar
+        * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)
+        * layernorm.weight.float()
+    )
     return X.to(old_dtype)
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
 def LlamaAttention_fast_forward(
     self,
-    hidden_states:       torch.Tensor,
-    causal_mask:         Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:      Optional[torch.Tensor] = None,
-    position_ids:        Optional[torch.LongTensor] = None,
-    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:   bool = False,
-    use_cache:           bool = False,
-    padding_mask:        Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -474,18 +548,17 @@ def LlamaAttention_fast_forward(
         del self.temp_KV
         del self.RH_Q
         del self.attention
-    pass
 
     bsz, q_len, _ = hidden_states.size()
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
+    head_dim = self.head_dim
+    assert n_kv_heads * n_groups == n_heads
 
     Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
+    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
     K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
@@ -517,11 +590,10 @@ def LlamaAttention_fast_forward(
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
         V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
     past_key_value = (K, V) if use_cache else None
 
     # Attention module
-    if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
+    if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
         # Xformers memory efficient attention
         # Also has Flash Attention v2 dispatching
         Q = Q.transpose(1, 2)
@@ -530,8 +602,8 @@ def LlamaAttention_fast_forward(
 
         # Group query attention
         if n_groups != 1:
-            K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
-            V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+            K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+            V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
             K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
             V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
             if hidden_states.requires_grad:
@@ -539,7 +611,6 @@ def LlamaAttention_fast_forward(
                 V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
             else:
                 Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
-        pass
         A = xformers_attention(Q, K, V, attn_bias = causal_mask)
         A = A.view(bsz, q_len, n_heads, head_dim)
 
@@ -560,46 +631,56 @@ def LlamaAttention_fast_forward(
         if SDPA_HAS_GQA:
             # Needs (batch_size, n_heads, seq_len, head_dim)
             # is_casual and attention_mask must not be both set!
-            A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = is_causal, enable_gqa = n_groups != 1)
+            A = scaled_dot_product_attention(
+                Q,
+                K,
+                V,
+                attn_mask = attention_mask,
+                is_causal = is_causal,
+                enable_gqa = n_groups != 1,
+            )
             # Go back to (batch_size, seq_len, n_heads, head_dim)
-            A = A.transpose(1, 2)#.contiguous()
+            A = A.transpose(1, 2)  # .contiguous()
         else:
             if n_groups != 1:
-                K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
-                V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+                K = K[:, :, None, :, :].expand(
+                    bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
+                )
+                V = V[:, :, None, :, :].expand(
+                    bsz, n_kv_heads, n_groups, kv_seq_len, head_dim
+                )
                 K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
                 V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
-            pass
             # Must be contiguous or else results are False!
             # https://github.com/pytorch/pytorch/issues/112577
             Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
             # Needs (batch_size, n_heads, seq_len, head_dim)
             # is_casual and attention_mask must not be both set!
-            A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = is_causal)
+            A = scaled_dot_product_attention(
+                Q, K, V, attn_mask = attention_mask, is_causal = is_causal
+            )
             # Go back to (batch_size, seq_len, n_heads, head_dim)
             A = A.transpose(1, 2).contiguous()
-        pass
-    pass
-    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
+    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
     attn_output = self.apply_o(self, attn_output)
     attn_weights = None
     return attn_output, attn_weights, past_key_value
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
 def LlamaDecoderLayer_fast_forward(
     self,
-    hidden_states:       torch.Tensor,
-    causal_mask          = None,
-    attention_mask:      Optional[torch.Tensor] = None,
-    position_ids:        Optional[torch.LongTensor] = None,
-    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:   Optional[bool] = False,
-    use_cache:           Optional[bool] = False,
-    padding_mask:        Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """
     Args:
@@ -616,37 +697,41 @@ def LlamaDecoderLayer_fast_forward(
     """
     if use_cache and hasattr(self, "_flag_for_generation"):
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.input_layernorm, hidden_states
+        )
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states       = hidden_states,
-            causal_mask         = causal_mask,
-            attention_mask      = attention_mask,
-            position_ids        = position_ids,
-            past_key_value      = past_key_value,
-            output_attentions   = output_attentions,
-            use_cache           = use_cache,
-            padding_mask        = padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
         )
         hidden_states += residual
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.post_attention_layernorm, hidden_states
+        )
         hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
         hidden_states += residual
     else:
         residual = hidden_states
         hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states       = hidden_states,
-            causal_mask         = causal_mask,
-            attention_mask      = attention_mask,
-            position_ids        = position_ids,
-            past_key_value      = past_key_value,
-            output_attentions   = output_attentions,
-            use_cache           = use_cache,
-            padding_mask        = padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
         )
         hidden_states = residual + hidden_states
@@ -656,13 +741,13 @@ def LlamaDecoderLayer_fast_forward(
         hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-    pass
 
     outputs = (hidden_states,)
-    if output_attentions: outputs += (self_attn_weights,)
-    if use_cache: outputs += (present_key_value,)
+    if output_attentions:
+        outputs += (self_attn_weights,)
+    if use_cache:
+        outputs += (present_key_value,)
     return outputs
-pass
 
 
 # https://github.com/unslothai/unsloth/issues/404#issuecomment-2323473452
@@ -675,40 +760,53 @@ def LlamaDecoderLayer_fast_forward(
     torch.bfloat16: torch.bfloat16,
 }
 
+
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
 def LlamaModel_fast_forward(
     self,
-    input_ids:            torch.LongTensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_values:      Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds:        Optional[torch.FloatTensor] = None,
-    use_cache:            Optional[bool] = None,
-    output_attentions:    Optional[bool] = None,
+    input_ids: torch.LongTensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
-    return_dict:          Optional[bool] = None,
-    *args, **kwargs,
+    return_dict: Optional[bool] = None,
+    *args,
+    **kwargs,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    assert(output_attentions is False)
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    assert output_attentions is False
     output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
     )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
 
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
 
     # retrieve input_ids and inputs_embeds
     if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        raise ValueError(
+            "Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+        )
     elif input_ids is not None:
         batch_size, seq_length = input_ids.shape
     elif inputs_embeds is not None:
         batch_size, seq_length, _ = inputs_embeds.shape
     else:
-        raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds")
+        raise ValueError(
+            "Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds"
+        )
 
     seq_length_with_past = seq_length
 
@@ -717,41 +815,37 @@ def LlamaModel_fast_forward(
         if seq_length > self.max_seq_length:
             shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
             logger.warning_once(
-                f"Unsloth: Input IDs of shape {shape} with length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"\
+                f"Unsloth: Input IDs of shape {shape} with length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"
                 "We shall truncate it ourselves. It's imperative if you correct this issue first."
             )
         if input_ids is not None:
-            input_ids = input_ids[:,:self.max_seq_length]
+            input_ids = input_ids[:, : self.max_seq_length]
         elif inputs_embeds is not None:
-            inputs_embeds = inputs_embeds[:,:self.max_seq_length,:]
-        pass
-    pass
+            inputs_embeds = inputs_embeds[:, : self.max_seq_length, :]
 
     past_key_values_length = 0
 
     if past_key_values is not None:
         past_key_values_length = past_key_values[0][0].shape[2]
         seq_length_with_past = seq_length_with_past + past_key_values_length
-    pass
 
     # We already handle KV cache position_ids ourselves.
-    if False:#(past_key_values_length != 0):
+    if False:  # (past_key_values_length != 0):
         position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length,
-            dtype  = torch.int32,
+            past_key_values_length,
+            seq_length + past_key_values_length,
+            dtype = torch.int32,
             device = f"{DEVICE_TYPE_TORCH}:0",
         )
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     elif position_ids is not None:
-        position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
+        position_ids = position_ids.view(-1, seq_length).to(torch.int32)  # .long()
     else:
         position_ids = None
-    pass
 
     if position_ids is not None:
         if position_ids.shape[0] != batch_size:
             position_ids = position_ids.repeat((batch_size, 1))
-    pass
 
     # Embed positions
     if inputs_embeds is None:
@@ -760,9 +854,9 @@ def LlamaModel_fast_forward(
     inputs_embeds = inputs_embeds.to(_get_dtype(dtype_from_config(self.config)))
 
     # Normalized from Gemma
-    IS_GEMMA   = self.config.model_type.startswith("gemma")
-    IS_GEMMA2  = self.config.model_type.startswith("gemma2")
-    IS_COHERE  = self.config.model_type.startswith("cohere")
+    IS_GEMMA = self.config.model_type.startswith("gemma")
+    IS_GEMMA2 = self.config.model_type.startswith("gemma2")
+    IS_COHERE = self.config.model_type.startswith("cohere")
     IS_GRANITE = self.config.model_type.startswith("granite")
     IS_FALCON_H1 = self.config.model_type.startswith("falcon_h1")
 
@@ -773,7 +867,9 @@ def LlamaModel_fast_forward(
         # inputs_embeds *= math_sqrt(self.config.hidden_size)
         # Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
         # &  2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
-        normalizer = torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
+        normalizer = torch.tensor(
+            math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype
+        )
 
         if train_embed_tokens:
             # Careful we must not do an inplace op!
@@ -785,17 +881,20 @@ def LlamaModel_fast_forward(
                 inputs_requires_grad = True
             elif inputs_requires_grad:
                 inputs_embeds.requires_grad_(False)
-            pass
             inputs_embeds *= normalizer
             # inputs_embeds *= math_sqrt(self.config.hidden_size)
-            if inputs_requires_grad: inputs_embeds.requires_grad_(True)
-        pass
-    pass
+            if inputs_requires_grad:
+                inputs_embeds.requires_grad_(True)
 
     # Fix up attention mask by setting elements to 0
     # Specifically for DPO
-    if getattr(self, "_has_no_labels", False) is True and (attention_mask is not None) and (past_key_values is None) and \
-        (not train_embed_tokens) and self.training:
+    if (
+        getattr(self, "_has_no_labels", False) is True
+        and (attention_mask is not None)
+        and (past_key_values is None)
+        and (not train_embed_tokens)
+        and self.training
+    ):
         # Careful for inference the attention_mask is size (1, kv_seq_len)
         # Whilst the input_embeds is size (1, 1, 4096)
         inputs_requires_grad = inputs_embeds.requires_grad
@@ -804,11 +903,10 @@ def LlamaModel_fast_forward(
             inputs_requires_grad = True
         elif inputs_requires_grad:
             inputs_embeds.requires_grad_(False)
-        pass
-        attention_mask = attention_mask[:,:self.max_seq_length] # Must resize!
+        attention_mask = attention_mask[:, : self.max_seq_length]  # Must resize!
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
-        if inputs_requires_grad: inputs_embeds.requires_grad_(True)
-    pass
+        if inputs_requires_grad:
+            inputs_embeds.requires_grad_(True)
 
     # Ignore attention_mask
     if attention_mask is None:
@@ -832,10 +930,9 @@ def LlamaModel_fast_forward(
         # Must NOT convert to bool - weirdly this causes stuff to error out!
         # if attention_mask is not None:
         #     attention_mask = attention_mask.to(torch.bool)
-    pass
 
     hidden_states = inputs_embeds
-    if IS_GRANITE or IS_FALCON_H1: #granite has embedding multiplier
+    if IS_GRANITE or IS_FALCON_H1:  # granite has embedding multiplier
         hidden_states = self.config.embedding_multiplier * hidden_states
 
     if past_key_values is None and self.training:
@@ -845,7 +942,6 @@ def LlamaModel_fast_forward(
         #         "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
         #     )
         #     use_cache = False
-    pass
 
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
@@ -857,23 +953,21 @@ def LlamaModel_fast_forward(
         boundaries = self._gradient_checkpointing_boundaries
     else:
         boundaries = None
-    pass
 
     # Check checkpointing method
     gradient_checkpointing = False
 
-    if (self.gradient_checkpointing and self.training and not use_cache):
+    if self.gradient_checkpointing and self.training and not use_cache:
         gradient_checkpointing = True
-    pass
 
     # Gemma2 has alternating SWA and global attn
-    use_static_mask  = True
+    use_static_mask = True
     dynamic_SWA_mask = None
-    dynamic_GA_mask  = None
+    dynamic_GA_mask = None
     if IS_GEMMA2:
         if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
             self.SWA_mask = True
-            self.GA_mask  = False
+            self.GA_mask = False
         elif attention_mask is not None:
             # Fixes https://github.com/unslothai/unsloth/issues/853
             # Unsloth needs a 2D mask, not a [2, 1, n, n] mask!
@@ -900,60 +994,95 @@ def LlamaModel_fast_forward(
         elif not hasattr(self, "SWA_mask"):
             if HAS_FLEX_ATTENTION:
                 # Use Flex Attention instead!
-                self.SWA_mask = create_flex_attention_sliding_window_mask(self.max_seq_length, self.config.sliding_window)
-                self.GA_mask  = create_flex_attention_causal_mask(self.max_seq_length)
+                self.SWA_mask = create_flex_attention_sliding_window_mask(
+                    self.max_seq_length, self.config.sliding_window
+                )
+                self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length)
             else:
-                n = self.max_seq_length # self.config.max_position_embeddings
+                n = self.max_seq_length  # self.config.max_position_embeddings
                 # masked_fill is making stuff slower!
                 # self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
                 # self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
                 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-                self.SWA_mask = AttentionMaskConverter(
-                    is_causal = True,
-                    sliding_window = self.config.sliding_window,
-                )\
-                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE_TORCH,)\
-                    .squeeze(0).squeeze(0)
-
-                self.GA_mask = AttentionMaskConverter(
-                    is_causal = True,
-                )\
-                    .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = DEVICE_TYPE_TORCH,)\
-                    .squeeze(0).squeeze(0)
-            pass
-        pass
-    pass
 
-    if (IS_ATTENTION_REFACTOR and (hasattr(self, "rotary_emb") or not hasattr(self.layers[0].self_attn, "rotary_emb"))) or IS_GRANITE:
+                self.SWA_mask = (
+                    AttentionMaskConverter(
+                        is_causal = True,
+                        sliding_window = self.config.sliding_window,
+                    )
+                    .to_causal_4d(
+                        1,
+                        n,
+                        n,
+                        dtype = inputs_embeds.dtype,
+                        device = DEVICE_TYPE_TORCH,
+                    )
+                    .squeeze(0)
+                    .squeeze(0)
+                )
+
+                self.GA_mask = (
+                    AttentionMaskConverter(
+                        is_causal = True,
+                    )
+                    .to_causal_4d(
+                        1,
+                        n,
+                        n,
+                        dtype = inputs_embeds.dtype,
+                        device = DEVICE_TYPE_TORCH,
+                    )
+                    .squeeze(0)
+                    .squeeze(0)
+                )
+
+    if (
+        IS_ATTENTION_REFACTOR
+        and (
+            hasattr(self, "rotary_emb")
+            or not hasattr(self.layers[0].self_attn, "rotary_emb")
+        )
+    ) or IS_GRANITE:
         # Transformers main has made it mandatory to pass position_embeddings
         # https://github.com/huggingface/transformers/pull/34858
         # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor)
         # unsloth's check for granite too has "version >= 4.45.0 (rightly so)".
         # so let granite always use the attention refactor implementation.
-        position_embeddings = self.rotary_emb.get_cached(self.config.max_position_embeddings, hidden_states.device.index)
+        position_embeddings = self.rotary_emb.get_cached(
+            self.config.max_position_embeddings, hidden_states.device.index
+        )
     else:
         position_embeddings = None
 
     # Go through every layer!
     for idx, decoder_layer in enumerate(self.layers):
-
-        if output_hidden_states: all_hidden_states += (hidden_states,)
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
         past_key_value = past_key_values[idx] if past_key_values is not None else None
 
         mask = causal_mask
         if IS_GEMMA2:
-            if (idx % 2 == 0):
+            if idx % 2 == 0:
                 mask = self.SWA_mask if use_static_mask else dynamic_SWA_mask
             else:
-                mask = self. GA_mask if use_static_mask else dynamic_GA_mask
-        pass
+                mask = self.GA_mask if use_static_mask else dynamic_GA_mask
+
+        if gradient_checkpointing and not isinstance(
+            decoder_layer, GradientCheckpointingLayer
+        ):
 
-        if gradient_checkpointing and not isinstance(decoder_layer, GradientCheckpointingLayer):
             def create_custom_forward(module):
                 def custom_forward(*inputs):
-                    return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings)
+                    return module(
+                        *inputs,
+                        past_key_value,
+                        output_attentions,
+                        padding_mask = padding_mask,
+                        position_embeddings = position_embeddings,
+                    )
+
                 return custom_forward
-            pass
+
             layer_outputs = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(decoder_layer),
                 hidden_states,
@@ -968,54 +1097,66 @@ def custom_forward(*inputs):
         else:
             layer_outputs = decoder_layer(
                 hidden_states,
-                causal_mask         = mask,
-                attention_mask      = attention_mask,
-                position_ids        = position_ids,
-                past_key_value      = past_key_value,
-                output_attentions   = output_attentions,
-                use_cache           = use_cache,
-                padding_mask        = padding_mask,
+                causal_mask = mask,
+                attention_mask = attention_mask,
+                position_ids = position_ids,
+                past_key_value = past_key_value,
+                output_attentions = output_attentions,
+                use_cache = use_cache,
+                padding_mask = padding_mask,
                 position_embeddings = position_embeddings,
             )
             hidden_states = layer_outputs[0]
-        pass
 
-        if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-        if output_attentions: all_self_attns += (layer_outputs[1],)
-    pass
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
 
     # Final layernorm
     if use_cache:
         if IS_FALCON_H1:
-            hidden_states = fast_rms_layernorm_inference(self.final_layernorm, hidden_states)
+            hidden_states = fast_rms_layernorm_inference(
+                self.final_layernorm, hidden_states
+            )
         else:
-            hidden_states = \
-                (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
-                (self.norm, hidden_states)
+            hidden_states = (
+                fast_rms_layernorm_inference_gemma
+                if IS_GEMMA
+                else fast_rms_layernorm_inference
+            )(self.norm, hidden_states)
     elif IS_COHERE:
         hidden_states = self.norm(hidden_states)
     elif IS_FALCON_H1:
-        hidden_states = fast_rms_layernorm(self.final_layernorm, hidden_states, gemma = IS_GEMMA)
+        hidden_states = fast_rms_layernorm(
+            self.final_layernorm, hidden_states, gemma = IS_GEMMA
+        )
     else:
         hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
-    pass
 
-    if output_hidden_states: all_hidden_states += (hidden_states,)
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
     next_cache = next_decoder_cache if use_cache else None
 
     if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(
+            v
+            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+            if v is not None
+        )
     return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
+        last_hidden_state = hidden_states,
+        past_key_values = next_cache,
+        hidden_states = all_hidden_states,
+        attentions = all_self_attns,
     )
-pass
 
 
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
-def _LlamaModel_fast_forward_inference(attention_fast_forward_inference=LlamaAttention_fast_forward_inference, mlp_fast_forward_inference=fast_swiglu_inference):
+def _LlamaModel_fast_forward_inference(
+    attention_fast_forward_inference = LlamaAttention_fast_forward_inference,
+    mlp_fast_forward_inference = fast_swiglu_inference,
+):
     # This makes the attention and MLP customisable.
     # Now for models like qwen3 or cohere which use custom attention operations, we can use this function
     def LlamaModel_fast_forward_inference_custom(
@@ -1025,7 +1166,7 @@ def LlamaModel_fast_forward_inference_custom(
         position_ids,
         attention_mask = None,
     ):
-        input_ids = input_ids[:,:self.max_seq_length]
+        input_ids = input_ids[:, : self.max_seq_length]
         bsz, q_len = input_ids.shape
         hd = self.config.hidden_size
         mlp_size = self.config.intermediate_size
@@ -1033,14 +1174,25 @@ def LlamaModel_fast_forward_inference_custom(
         X = self.model.embed_tokens(input_ids)
         X = X.to(_get_dtype(dtype_from_config(self.config)))
         bsz, q_len, hd = X.shape
-        assert(q_len == 1)
+        assert q_len == 1
         # Get saved buffers to reduce memory movement
-        residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
-        _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
+        residual = torch.empty(
+            (bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
+        )
+        _XX = torch.empty(
+            (2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
+        )
         XX, XX2 = _XX[0], _XX[1]
-        variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0")
-        temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0")
-        temp_gates, temp_ups = tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)), tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT))
+        variance = torch.empty(
+            (bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
+        )
+        temp_mlp = torch.empty(
+            (2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0"
+        )
+        temp_gates, temp_ups = (
+            tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)),
+            tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT)),
+        )
 
         seq_len = past_key_values[0][0].shape[-2]
         if bsz != 1:
@@ -1053,7 +1205,6 @@ def LlamaModel_fast_forward_inference_custom(
             )
         else:
             attention_mask = None
-        pass
 
         next_decoder_cache = []
 
@@ -1062,7 +1213,7 @@ def LlamaModel_fast_forward_inference_custom(
             X, residual, position_ids = move_to_device(
                 device_index, X, residual, position_ids
             )
-            residual.copy_(X) # residual = X
+            residual.copy_(X)  # residual = X
             X = fast_rms_layernorm_inference(
                 decoder_layer.input_layernorm,
                 X,
@@ -1080,7 +1231,7 @@ def LlamaModel_fast_forward_inference_custom(
             )
             X += residual
 
-            residual.copy_(X) # residual = X
+            residual.copy_(X)  # residual = X
             X = fast_rms_layernorm_inference(
                 decoder_layer.post_attention_layernorm,
                 X,
@@ -1097,7 +1248,6 @@ def LlamaModel_fast_forward_inference_custom(
             X += residual
 
             next_decoder_cache.append(present_key_value)
-        pass
         X = fast_rms_layernorm_inference(
             self.model.norm,
             X,
@@ -1112,12 +1262,14 @@ def LlamaModel_fast_forward_inference_custom(
             hidden_states = [],
             attentions = [],
         )
-    pass
+
     return LlamaModel_fast_forward_inference_custom
 
+
 # For ensuring backwards compatibility, we create LlamaModel_fast_forward_inference that is consumed by other models
 LlamaModel_fast_forward_inference = _LlamaModel_fast_forward_inference()
 
+
 def CausalLM_fast_forward(fast_forward_inference):
     def _CausalLM_fast_forward(
         self,
@@ -1134,7 +1286,8 @@ def _CausalLM_fast_forward(
         return_dict: Optional[bool] = None,
         num_logits_to_keep: Optional[int] = 0,
         logits_to_keep: Optional[int] = 0,
-        *args, **kwargs,
+        *args,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         if past_key_values is not None:
             outputs = fast_forward_inference(
@@ -1145,13 +1298,23 @@ def _CausalLM_fast_forward(
                 attention_mask = attention_mask,
             )
         else:
-            causal_mask = xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None
+            causal_mask = (
+                xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None
+            )
 
-            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+            output_attentions = (
+                output_attentions
+                if output_attentions is not None
+                else self.config.output_attentions
+            )
             output_hidden_states = (
-                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+                output_hidden_states
+                if output_hidden_states is not None
+                else self.config.output_hidden_states
+            )
+            return_dict = (
+                return_dict if return_dict is not None else self.config.use_return_dict
             )
-            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
             # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
             self.model._has_no_labels = labels is None
             outputs = self.model(
@@ -1166,7 +1329,6 @@ def _CausalLM_fast_forward(
                 output_hidden_states = output_hidden_states,
                 return_dict = return_dict,
             )
-        pass
         hidden_states = outputs[0]
 
         bsz, q_len, hd = hidden_states.shape
@@ -1174,13 +1336,14 @@ def _CausalLM_fast_forward(
         lm_head_device = lm_head.device
 
         logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
-        logit_scaling     = getattr(self.config, "logit_scale", 0)
+        logit_scaling = getattr(self.config, "logit_scale", 0)
         dtype = lm_head.dtype
         num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
 
         # Move items to same device as lm_head
         hidden_states = hidden_states.to(lm_head_device)
-        if labels is not None: labels = labels.to(lm_head_device)
+        if labels is not None:
+            labels = labels.to(lm_head_device)
 
         # Output last hidden states without logits if asked
         if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
@@ -1191,9 +1354,8 @@ def _CausalLM_fast_forward(
                 logits = hidden_states,
                 past_key_values = outputs.past_key_values,
                 hidden_states = outputs.hidden_states,
-                attentions=  outputs.attentions,
+                attentions = outputs.attentions,
             )
-        pass
 
         if bsz == 1 and q_len == 1:
             logits = torch.mv(lm_head, hidden_states.ravel().to(dtype))
@@ -1209,7 +1371,8 @@ def _CausalLM_fast_forward(
 
             if not RETURN_LOGITS and labels is not None:
                 n_items = kwargs.get("num_items_in_batch", None)
-                if n_items is None: n_items = kwargs.get("n_items", None)
+                if n_items is None:
+                    n_items = kwargs.get("n_items", None)
 
                 if self.config.model_type == "falcon_h1":
                     hidden_states = hidden_states * self.config.lm_head_multiplier
@@ -1224,17 +1387,17 @@ def _CausalLM_fast_forward(
                 #     logit_softcapping  = logit_softcapping,
                 # )
                 loss = unsloth_fused_ce_loss(
-                    trainer              = None,
-                    hidden_states        = hidden_states,
-                    lm_head_weight       = lm_head,
-                    lm_head_bias         = None,
-                    labels               = labels,
-                    mask                 = None,
-                    n_items              = n_items,
-                    scaling              = getattr(self, "accelerator_scaler", None),
-                    target_gb            = None,
-                    torch_compile        = True,
-                    logit_softcapping    = logit_softcapping,
+                    trainer = None,
+                    hidden_states = hidden_states,
+                    lm_head_weight = lm_head,
+                    lm_head_bias = None,
+                    labels = labels,
+                    mask = None,
+                    n_items = n_items,
+                    scaling = getattr(self, "accelerator_scaler", None),
+                    target_gb = None,
+                    torch_compile = True,
+                    logit_softcapping = logit_softcapping,
                 )
                 if not return_dict:
                     output = (logits,) + outputs[1:]
@@ -1243,19 +1406,17 @@ def _CausalLM_fast_forward(
                 output = CausalLMOutputWithPast(
                     loss = loss,
                     logits = EMPTY_LOGITS,
-                    past_key_values=  outputs.past_key_values,
+                    past_key_values = outputs.past_key_values,
                     hidden_states = outputs.hidden_states,
                     attentions = outputs.attentions,
                 )
                 return output
-            pass
             logits = self.lm_head(hidden_states.to(dtype))
-        pass
 
         logits = logits.to(_get_dtype(dtype_from_config(self.config)))
         loss = None
         logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
-        logit_scaling     = getattr(self.config, "logit_scale", 0)
+        logit_scaling = getattr(self.config, "logit_scale", 0)
         if self.config.model_type == "granite":
             # granite uses logit_scaling as key and they divide by the scale unlike cohere
             # notice that for granite, logits_scale is 16 and for cohere it is 0.125 (aka 1/8) in their respective configs
@@ -1276,13 +1437,14 @@ def _CausalLM_fast_forward(
             shift_labels[..., -1] = -100
             # shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
             n_items = kwargs.get("num_items_in_batch", None)
-            if n_items is None: n_items = kwargs.get("n_items", None)
+            if n_items is None:
+                n_items = kwargs.get("n_items", None)
             loss = fast_cross_entropy_loss(
                 logits = shift_logits,
                 labels = shift_labels,
                 logit_softcapping = logit_softcapping,
-                logit_scaling     = logit_scaling,
-                n_items           = n_items,
+                logit_scaling = logit_scaling,
+                n_items = n_items,
             )
         else:
             if logit_scaling != 0:
@@ -1290,20 +1452,15 @@ def _CausalLM_fast_forward(
                     logits = logit_scaling * logits
                 else:
                     logits *= logit_scaling
-                pass
-            pass
             if logit_softcapping != 0:
                 if logits.requires_grad:
                     logits = (1.0 / logit_softcapping) * logits
                     logits = torch.tanh(logits)
                     logits = logit_softcapping * logits
                 else:
-                    logits *= (1.0 / logit_softcapping)
+                    logits *= 1.0 / logit_softcapping
                     torch.tanh(logits, out = logits)
                     logits *= logit_softcapping
-                pass
-            pass
-        pass
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1313,11 +1470,10 @@ def _CausalLM_fast_forward(
             logits = logits,
             past_key_values = outputs.past_key_values,
             hidden_states = outputs.hidden_states,
-            attentions=  outputs.attentions,
+            attentions = outputs.attentions,
         )
-    pass
+
     return _CausalLM_fast_forward
-pass
 
 
 @torch._disable_dynamo
@@ -1362,7 +1518,6 @@ def PeftModel_fast_forward(
             logits_to_keep = logits_to_keep,
             **kwargs,
         )
-pass
 
 
 # Solves https://github.com/unslothai/unsloth/issues/168
@@ -1374,197 +1529,257 @@ class LlamaRotaryEmbedding(torch.nn.Module):
     # Fixes https://github.com/huggingface/transformers/pull/28837
     # https://github.com/microsoft/DeepSpeed/issues/4932
     # The precision of RoPE buffers is not correct, so we cast to int64.
-    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
-        config = None, # [TODO] Hack to pass in config - need to remove later
+    def __init__(
+        self,
+        dim = None,
+        max_position_embeddings = 2048,
+        base = 10000,
+        device = None,
+        config = None,  # [TODO] Hack to pass in config - need to remove later
     ):
         super().__init__()
         if config is not None:
             # [TODO] Hack to pass in config - need to remove later
             base = config.rope_theta
-            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            partial_rotary_factor = (
+                config.partial_rotary_factor
+                if hasattr(config, "partial_rotary_factor")
+                else 1.0
+            )
             dim = getattr(config, "head_dim", None)
-            if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
+            if dim is None:
+                dim = int((config.hidden_size // config.num_attention_heads))
             device = DEVICE_TYPE_TORCH
             max_position_embeddings = config.max_position_embeddings
-        pass
 
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
-        self.multi_gpu_cos_cached = [None]*DEVICE_COUNT
-        self.multi_gpu_sin_cached = [None]*DEVICE_COUNT
+        self.multi_gpu_cos_cached = [None] * DEVICE_COUNT
+        self.multi_gpu_sin_cached = [None] * DEVICE_COUNT
 
         # Build here to make `torch.jit.trace` work.
         for device_idx in range(DEVICE_COUNT):
-            self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())
+            self._set_cos_sin_cache(
+                seq_len = self.current_rope_size,
+                device = torch.device(device_idx),
+                dtype = torch.get_default_dtype(),
+            )
 
         # dummy so that patch_utils doesn't fail for now
-        self.cos_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-        self.sin_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-    pass
+        self.cos_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
+        self.sin_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
         # in FP32. They are applied (multiplied) in FP32 as well.
         self.current_rope_size = seq_len
         inv_freq = 1.0 / (
-            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
+            self.base
+            ** (
+                torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
+                / self.dim
+            )
         )
-        t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
+        t = torch.arange(
+            self.current_rope_size, device = "cpu", dtype = torch.int64
+        ).float()
 
         freqs = torch.outer(t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
-        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+        emb = torch.cat((freqs, freqs), dim = -1)
+        cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)
+        sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)
         self.multi_gpu_cos_cached[device.index] = cos
         self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
-    pass
 
-    def forward(self, x, position_ids=None, seq_len=None):
+    def forward(self, x, position_ids = None, seq_len = None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len is not None and seq_len > self.current_rope_size:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+            self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
 
         device_index = x.device.index
         return (
             self.multi_gpu_cos_cached[device_index][:seq_len],
             self.multi_gpu_sin_cached[device_index][:seq_len],
         )
-    pass
 
     def get_cached(self, seq_len = None, device_index = None):
         if device_index is None:
             device_index = get_current_device()
-        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index]
-    pass
+        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
+            device_index
+        ]
 
     def extend_rope_embedding(self, x, seq_len):
-        if seq_len <= self.current_rope_size: return
+        if seq_len <= self.current_rope_size:
+            return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         for device_idx in range(DEVICE_COUNT):
-            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
-    pass
-pass
+            self._set_cos_sin_cache(
+                self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
+            )
 
 
 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
     """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
     # Fixes https://github.com/huggingface/transformers/pull/28837
     # https://github.com/microsoft/DeepSpeed/issues/4932
     # The precision of RoPE buffers is not correct, so we cast to int64.
-    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
-        config = None, # [TODO] Hack to pass in config - need to remove later
+    def __init__(
+        self,
+        dim = None,
+        max_position_embeddings = 2048,
+        base = 10000,
+        device = None,
+        scaling_factor = 1.0,
+        config = None,  # [TODO] Hack to pass in config - need to remove later
     ):
         self.scaling_factor = scaling_factor
-        super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
-    pass
+        super().__init__(
+            dim = dim,
+            max_position_embeddings = max_position_embeddings,
+            base = base,
+            device = device,
+            config = config,
+        )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.current_rope_size = seq_len
         inv_freq = 1.0 / (
-            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
+            self.base
+            ** (
+                torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
+                / self.dim
+            )
         )
-        t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
+        t = torch.arange(
+            self.current_rope_size, device = "cpu", dtype = torch.int64
+        ).float()
         t = t / self.scaling_factor
 
         freqs = torch.outer(t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
-        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+        emb = torch.cat((freqs, freqs), dim = -1)
+        cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)
+        sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)
         self.multi_gpu_cos_cached[device.index] = cos
         self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
-    pass
-pass
 
 
 # See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L736
 # For Llama 3.1
 class LlamaExtendedRotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
-        config = None, # [TODO] Hack to pass in config - need to remove later
+    def __init__(
+        self,
+        dim = None,
+        max_position_embeddings = 2048,
+        base = 10000,
+        device = None,
+        config = None,  # [TODO] Hack to pass in config - need to remove later
     ):
         super().__init__()
         if config is not None:
             # [TODO] Hack to pass in config - need to remove later
             base = config.rope_theta
-            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            partial_rotary_factor = (
+                config.partial_rotary_factor
+                if hasattr(config, "partial_rotary_factor")
+                else 1.0
+            )
             dim = int((config.hidden_size // config.num_attention_heads))
             device = DEVICE_TYPE_TORCH
             max_position_embeddings = config.max_position_embeddings
-        pass
 
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
-        self.multi_gpu_cos_cached = [None]*DEVICE_COUNT
-        self.multi_gpu_sin_cached = [None]*DEVICE_COUNT
+        self.multi_gpu_cos_cached = [None] * DEVICE_COUNT
+        self.multi_gpu_sin_cached = [None] * DEVICE_COUNT
 
         # Normal Llama-3 RoPE
         inv_freq = 1.0 / (
-            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
+            self.base
+            ** (
+                torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
+                / self.dim
+            )
         )
         inv_freq = self.apply_scaling(inv_freq)
         self.register_buffer("inv_freq", inv_freq, persistent = False)
 
         # Build here to make `torch.jit.trace` work.
         for device_idx in range(DEVICE_COUNT):
-            self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())
+            self._set_cos_sin_cache(
+                seq_len = self.current_rope_size,
+                device = torch.device(device_idx),
+                dtype = torch.get_default_dtype(),
+            )
 
         # dummy so that patch_utils doesn't fail for now
-        self.cos_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-        self.sin_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-    pass
+        self.cos_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
+        self.sin_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
         # in FP32. They are applied (multiplied) in FP32 as well.
         self.current_rope_size = seq_len
 
-        t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float()
+        t = torch.arange(
+            self.current_rope_size, device = self.inv_freq.device, dtype = torch.int64
+        ).float()
 
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
-        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+        emb = torch.cat((freqs, freqs), dim = -1)
+        cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)
+        sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)
         self.multi_gpu_cos_cached[device.index] = cos
         self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
-    pass
 
-    def forward(self, x, position_ids=None, seq_len=None):
+    def forward(self, x, position_ids = None, seq_len = None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len is not None and seq_len > self.current_rope_size:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+            self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
         device_index = x.device.index
         return (
             self.multi_gpu_cos_cached[device_index][:seq_len],
             self.multi_gpu_sin_cached[device_index][:seq_len],
         )
-    pass
 
     def get_cached(self, seq_len = None, device_index = None):
         if device_index is None:
             device_index = get_current_device()
-        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index]
-    pass
+        return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
+            device_index
+        ]
 
     def extend_rope_embedding(self, x, seq_len):
-        if seq_len <= self.current_rope_size: return
+        if seq_len <= self.current_rope_size:
+            return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         for device_idx in range(DEVICE_COUNT):
-            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
-    pass
+            self._set_cos_sin_cache(
+                self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
+            )
 
     # From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
     def apply_scaling(self, freqs: torch.Tensor):
@@ -1589,110 +1804,138 @@ def apply_scaling(self, freqs: torch.Tensor):
                     high_freq_factor - low_freq_factor
                 )
                 new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
-        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
-    pass
-pass
+        return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device)
 
 
 class LongRopeRotaryEmbedding(torch.nn.Module):
     # For Phi 3.5 128K https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/modeling_phi3.py
-    def __init__(self,
+    def __init__(
+        self,
         dim = None,
         max_position_embeddings = 131072,
         original_max_position_embeddings = 4096,
         base = 10000,
         short_factor = None,
-        long_factor  = None,
+        long_factor = None,
         device = None,
-        config = None, # [TODO] Hack to pass in config - need to remove later
+        config = None,  # [TODO] Hack to pass in config - need to remove later
     ):
         super().__init__()
-        assert(short_factor is not None)
-        assert(long_factor  is not None)
-        assert(type(original_max_position_embeddings) is int)
+        assert short_factor is not None
+        assert long_factor is not None
+        assert type(original_max_position_embeddings) is int
 
         if config is not None:
             # [TODO] Hack to pass in config - need to remove later
             base = config.rope_theta
-            partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+            partial_rotary_factor = (
+                config.partial_rotary_factor
+                if hasattr(config, "partial_rotary_factor")
+                else 1.0
+            )
             dim = int((config.hidden_size // config.num_attention_heads))
             device = DEVICE_TYPE_TORCH
             max_position_embeddings = config.max_position_embeddings
-        pass
 
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
-        self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings)
-        self.multi_gpu_short_cos_cached = [None]*DEVICE_COUNT
-        self.multi_gpu_short_sin_cached = [None]*DEVICE_COUNT
-        self.multi_gpu_long_cos_cached = [None]*DEVICE_COUNT
-        self.multi_gpu_long_sin_cached = [None]*DEVICE_COUNT
+        self.current_rope_size = min(
+            original_max_position_embeddings, self.max_position_embeddings
+        )
+        self.multi_gpu_short_cos_cached = [None] * DEVICE_COUNT
+        self.multi_gpu_short_sin_cached = [None] * DEVICE_COUNT
+        self.multi_gpu_long_cos_cached = [None] * DEVICE_COUNT
+        self.multi_gpu_long_sin_cached = [None] * DEVICE_COUNT
 
         # Long RoPE similar to RoPE except short sequences have 1 cos / sin
         # and long sequences have another cos / sin
-        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim
+        inv_freq_shape = (
+            torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
+            / self.dim
+        )
         short_factor = torch.tensor(short_factor, device = "cpu", dtype = torch.float32)
-        long_factor  = torch.tensor(long_factor,  device = "cpu", dtype = torch.float32)
+        long_factor = torch.tensor(long_factor, device = "cpu", dtype = torch.float32)
         short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape)
-        long_inv_freq  = 1.0 / (long_factor  * self.base**inv_freq_shape)
+        long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape)
 
         # Phi-3 Scale factor
         scale = self.max_position_embeddings / self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
         else:
-            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-        pass
+            scaling_factor = math.sqrt(
+                1 + math.log(scale) / math.log(self.original_max_position_embeddings)
+            )
         self.scaling_factor = scaling_factor
 
         # Short and long inv_freq
         self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
-        self.register_buffer("long_inv_freq",  long_inv_freq,  persistent = False)
+        self.register_buffer("long_inv_freq", long_inv_freq, persistent = False)
 
         # Build here to make `torch.jit.trace` work.
         # Initialize short sequences cache for all devices
         dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
-        t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float()
+        t = torch.arange(
+            original_max_position_embeddings,
+            device = self.short_inv_freq.device,
+            dtype = torch.int64,
+        ).float()
         freqs = torch.outer(t, self.short_inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1)
+        emb = torch.cat((freqs, freqs), dim = -1)
 
         for device_idx in range(DEVICE_COUNT):
             device_obj = torch.device(device_idx)
-            cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
-            sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
+            cos_cached = (emb.cos() * self.scaling_factor).to(
+                dtype = dtype, device = device_obj, non_blocking = True
+            )
+            sin_cached = (emb.sin() * self.scaling_factor).to(
+                dtype = dtype, device = device_obj, non_blocking = True
+            )
             self.multi_gpu_short_cos_cached[device_idx] = cos_cached
             self.multi_gpu_short_sin_cached[device_idx] = sin_cached
 
         # dummy so that patch_utils doesn't fail for now
-        self.short_cos_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-        self.short_sin_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-        self.long_cos_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-        self.long_sin_cached = torch.empty(1, device=get_current_device(), dtype=torch.get_default_dtype())
-    pass
+        self.short_cos_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
+        self.short_sin_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
+        self.long_cos_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
+        self.long_sin_cached = torch.empty(
+            1, device = get_current_device(), dtype = torch.get_default_dtype()
+        )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
         # in FP32. They are applied (multiplied) in FP32 as well.
         self.current_rope_size = seq_len
 
-        t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float()
+        t = torch.arange(
+            self.current_rope_size, device = self.long_inv_freq.device, dtype = torch.int64
+        ).float()
         # Long sequences
         freqs = torch.outer(t, self.long_inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
-        sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+        emb = torch.cat((freqs, freqs), dim = -1)
+        cos_cached = (emb.cos() * self.scaling_factor).to(
+            dtype = dtype, device = device, non_blocking = True
+        )
+        sin_cached = (emb.sin() * self.scaling_factor).to(
+            dtype = dtype, device = device, non_blocking = True
+        )
         self.multi_gpu_long_cos_cached[device.index] = cos_cached
         self.multi_gpu_long_sin_cached[device.index] = sin_cached
         return cos_cached, sin_cached
-    pass
 
-    def forward(self, x, position_ids=None, seq_len=None):
+    def forward(self, x, position_ids = None, seq_len = None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len is not None and seq_len > self.current_rope_size:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+            self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
 
         device_index = x.device.index
 
@@ -1706,25 +1949,27 @@ def forward(self, x, position_ids=None, seq_len=None):
                 self.multi_gpu_long_cos_cached[device_index][:seq_len],
                 self.multi_gpu_long_sin_cached[device_index][:seq_len],
             )
-        pass
-    pass
 
     def get_cached(self, seq_len = None, device_index = None):
         if device_index is None:
             device_index = get_current_device()
         if seq_len is not None and seq_len < self.original_max_position_embeddings:
-            return self.multi_gpu_short_cos_cached[device_index], self.multi_gpu_short_sin_cached[device_index]
-        return self.multi_gpu_long_cos_cached[device_index], self.multi_gpu_long_sin_cached[device_index]
-    pass
+            return self.multi_gpu_short_cos_cached[
+                device_index
+            ], self.multi_gpu_short_sin_cached[device_index]
+        return self.multi_gpu_long_cos_cached[
+            device_index
+        ], self.multi_gpu_long_sin_cached[device_index]
 
     def extend_rope_embedding(self, x, seq_len):
-        if seq_len <= self.current_rope_size: return
+        if seq_len <= self.current_rope_size:
+            return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
         for device_idx in range(DEVICE_COUNT):
-            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
-    pass
-pass
+            self._set_cos_sin_cache(
+                self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
+            )
 
 
 def unsloth_fast_generate(
@@ -1737,13 +1982,19 @@ def unsloth_fast_generate(
     dtype = _get_dtype(dtype_from_config(self.config))
 
     if hasattr(self, "config") and hasattr(self.config, "max_position_embeddings"):
-        if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs:
-            if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings:
+        if (
+            "input_ids" in kwargs
+            and kwargs["input_ids"] is not None
+            and "max_new_tokens" in kwargs
+        ):
+            if (
+                kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"]
+                > self.config.max_position_embeddings
+            ):
                 raise ValueError(
-                    f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n'\
-                    'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.'
+                    f"Unsloth: input length {kwargs['input_ids'].shape[-1]} + max_new_tokens {kwargs['max_new_tokens']} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n"
+                    "You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`."
                 )
-    pass
 
     # Must patch accelerate for Xformers
     # if accelerate_new_send_to_device is not None:
@@ -1755,7 +2006,7 @@ def unsloth_fast_generate(
     kwargs["cache_implementation"] = "dynamic"
     # For num_logits_to_keep
     num_logits_to_keep = kwargs.get("num_logits_to_keep", None)
-    logits_to_keep     = kwargs.get("logits_to_keep",     None)
+    logits_to_keep = kwargs.get("logits_to_keep", None)
     if num_logits_to_keep is None and logits_to_keep is None:
         kwargs["num_logits_to_keep"] = 1
 
@@ -1770,9 +2021,11 @@ def unsloth_fast_generate(
     kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
 
     # Mixed precision autocast
-    with torch.inference_mode(), torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype):
+    with (
+        torch.inference_mode(),
+        torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype),
+    ):
         output = self._old_generate(*args, **kwargs)
-    pass
 
     # Return accelerate back
     # if accelerate_new_send_to_device is not None:
@@ -1782,31 +2035,30 @@ def unsloth_fast_generate(
     FastLlamaModel.for_training(self)
 
     return output
-pass
 
 
 class FastLlamaModel:
-
     @staticmethod
     def pre_patch():
         init_name, function = patch_llama_rope_scaling(
-            model_name           = "llama",
-            rope_module          = LlamaRotaryEmbedding,
-            scaled_rope_module   = LlamaLinearScalingRotaryEmbedding,
+            model_name = "llama",
+            rope_module = LlamaRotaryEmbedding,
+            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
             extended_rope_module = LlamaExtendedRotaryEmbedding,
-            attention_module     = LlamaAttention,
-            longrope_module      = LongRopeRotaryEmbedding,
+            attention_module = LlamaAttention,
+            longrope_module = LongRopeRotaryEmbedding,
         )
         if init_name is not None:
             exec(function, globals())
-            LlamaAttention.__init__  = eval(init_name)
-        pass
-        LlamaAttention      .forward = LlamaAttention_fast_forward
-        LlamaSdpaAttention  .forward = LlamaAttention_fast_forward
+            LlamaAttention.__init__ = eval(init_name)
+        LlamaAttention.forward = LlamaAttention_fast_forward
+        LlamaSdpaAttention.forward = LlamaAttention_fast_forward
         LlamaFlashAttention2.forward = LlamaAttention_fast_forward
-        LlamaDecoderLayer   .forward = LlamaDecoderLayer_fast_forward
-        LlamaModel          .forward = LlamaModel_fast_forward
-        LlamaForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
+        LlamaDecoderLayer.forward = LlamaDecoderLayer_fast_forward
+        LlamaModel.forward = LlamaModel_fast_forward
+        LlamaForCausalLM.forward = CausalLM_fast_forward(
+            LlamaModel_fast_forward_inference
+        )
         PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(LlamaForCausalLM)
 
@@ -1816,47 +2068,50 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.llama.modeling_llama
-        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
-        transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
-        return
-    pass
 
+        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = (
+            LlamaRotaryEmbedding
+        )
+        transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = (
+            LlamaLinearScalingRotaryEmbedding
+        )
+        return
 
     @staticmethod
     def from_pretrained(
-        model_name         = "unsloth/llama-3-8b-bnb-4bit",
-        max_seq_length     = None,
-        dtype              = None,
-        load_in_4bit       = True,
-        token              = None,
-        device_map         = "sequential",
-        rope_scaling       = None,
-        fix_tokenizer      = True,
-        model_patcher      = None,
-        tokenizer_name     = None,
-        trust_remote_code  = False,
-        revision           = None,
-
-        fast_inference    = False, # uses vLLM
+        model_name = "unsloth/llama-3-8b-bnb-4bit",
+        max_seq_length = None,
+        dtype = None,
+        load_in_4bit = True,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,
+        fix_tokenizer = True,
+        model_patcher = None,
+        tokenizer_name = None,
+        trust_remote_code = False,
+        revision = None,
+        fast_inference = False,  # uses vLLM
         gpu_memory_utilization = 0.5,
-        float8_kv_cache   = False,
-        random_state      = 3407,
-        max_lora_rank     = 16,
+        float8_kv_cache = False,
+        random_state = 3407,
+        max_lora_rank = 16,
         disable_log_stats = False,
         unsloth_vllm_standby = False,
-        num_labels =  None,
+        num_labels = None,
         qat_scheme = None,
         **kwargs,
     ):
         os.environ["UNSLOTH_USE_NEW_MODEL"] = "0"
         if trust_remote_code:
             if fast_inference:
-                raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.")
+                raise NotImplementedError(
+                    "Unsloth: Fast inference does not support `trust_remote_code` yet."
+                )
             print(
-                "Unsloth: WARNING `trust_remote_code` is True.\n"\
+                "Unsloth: WARNING `trust_remote_code` is True.\n"
                 "Are you certain you want to do remote code execution?"
             )
-        pass
         if fast_inference:
             if not is_vLLM_available():
                 print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
@@ -1864,79 +2119,108 @@ def from_pretrained(
             if DEVICE_TYPE == "cuda":
                 major_version, minor_version = torch.cuda.get_device_capability()
                 if major_version < 7:
-                    print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!")
+                    print(
+                        "Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!"
+                    )
                     fast_inference = False
             elif DEVICE_TYPE == "hip":
                 fast_inference = True
-            if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0":
-                raise RuntimeError("Unsloth: `unsloth_vllm_standby` is True, but  environment variable `UNSLOTH_VLLM_STANDBY` is not set to 1!")
-        pass
+            if (
+                unsloth_vllm_standby
+                and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0"
+            ):
+                raise RuntimeError(
+                    "Unsloth: `unsloth_vllm_standby` is True, but  environment variable `UNSLOTH_VLLM_STANDBY` is not set to 1!"
+                )
 
-        if token is None: token = get_token()
-        if model_patcher is None: model_patcher = FastLlamaModel
+        if token is None:
+            token = get_token()
+        if model_patcher is None:
+            model_patcher = FastLlamaModel
         SUPPORTS_BFLOAT16 = is_bfloat16_supported()
 
         if DEVICE_TYPE == "cuda":
             gpu_stats = torch.cuda.get_device_properties(0)
-            gpu_stats_name = gpu_stats.name + ". " if gpu_stats.name != "" else "NVIDIA GPU Device. "
+            gpu_stats_name = (
+                gpu_stats.name + ". " if gpu_stats.name != "" else "NVIDIA GPU Device. "
+            )
             gpu_version = torch.version.cuda
             gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
-            try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except: vllm_version = ""
+            try:
+                vllm_version = f" vLLM: {importlib_version('vllm')}."
+            except:
+                vllm_version = ""
         elif DEVICE_TYPE == "hip":
             gpu_stats = torch.cuda.get_device_properties(0)
-            gpu_stats_name = gpu_stats.name + ". " if gpu_stats.name != "" else "AMD GPU Device. "
+            gpu_stats_name = (
+                gpu_stats.name + ". " if gpu_stats.name != "" else "AMD GPU Device. "
+            )
             gpu_version = torch.version.hip
             gpu_stats_snippet = f"ROCm Toolkit: {gpu_version}."
-            try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except: vllm_version = ""
+            try:
+                vllm_version = f" vLLM: {importlib_version('vllm')}."
+            except:
+                vllm_version = ""
         elif DEVICE_TYPE == "xpu":
             gpu_stats = torch.xpu.get_device_properties(0)
-            gpu_stats_name = gpu_stats.name + ". " if gpu_stats.name != "" else "Intel XPU Device. "
+            gpu_stats_name = (
+                gpu_stats.name + ". " if gpu_stats.name != "" else "Intel XPU Device. "
+            )
             gpu_version = torch.version.xpu
             gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
-            try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except: vllm_version = ""
+            try:
+                vllm_version = f" vLLM: {importlib_version('vllm')}."
+            except:
+                vllm_version = ""
         else:
             raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")
 
         max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
 
-        statistics = \
-        f"==((====))==  Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\
-        f"   {chr(92)}{chr(92)}   /|    {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
-        f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\
-        f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
-        f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+        statistics = (
+            f"==((====))==  Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"
+            f"   {chr(92)}{chr(92)}   /|    {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"
+            f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"
+            f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"
+            f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+        )
 
         print(statistics)
 
         # Warn about fast transfers
         if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
             old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"]
-            if old_hf_transfer in ("False", "false"): old_hf_transfer = "0"
-            if old_hf_transfer in ("True",  "true" ): old_hf_transfer = "1"
+            if old_hf_transfer in ("False", "false"):
+                old_hf_transfer = "0"
+            if old_hf_transfer in ("True", "true"):
+                old_hf_transfer = "1"
         else:
             old_hf_transfer = "0"
         if old_hf_transfer == "1":
-            print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
-        pass
-        if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+            print(
+                "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!"
+            )
+        if old_hf_transfer != "0":
+            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
         model_patcher.pre_patch()
-         # For debugging - we use a download counter to see if environments are not breaking or if HF is down
+        # For debugging - we use a download counter to see if environments are not breaking or if HF is down
         get_statistics(kwargs.get("local_files_only", False))
 
         if dtype is None:
             dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
         elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
-            logger.warning_once("Device does not support bfloat16. Will change to float16.")
+            logger.warning_once(
+                "Device does not support bfloat16. Will change to float16."
+            )
             dtype = torch.float16
         # elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
         #     logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
         #     dtype = torch.bfloat16
 
-        assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
+        assert (
+            dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32
+        )
 
         # RoPE Scaling
         model_config = AutoConfig.from_pretrained(
@@ -1955,64 +2239,66 @@ def from_pretrained(
         try:
             with open(inspect.getfile(model_function), "r", encoding = "utf-8") as file:
                 has_rope_scaling = "self.config.rope_scaling" in file.read()
-        except: pass
+        except:
+            pass
         has_rope_scaling = True
 
         # If max_seq_length is not specified, use maximum from config
         if max_seq_length is None:
             max_seq_length = model_max_seq_length
-        pass
 
         if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
-
             rope_scaling = max_seq_length / model_max_seq_length
 
             if fast_inference:
-                raise NotImplementedError("Unsloth: Fast inference does not yet work with RoPE Scaling.")
+                raise NotImplementedError(
+                    "Unsloth: Fast inference does not yet work with RoPE Scaling."
+                )
 
             logger.warning_once(
-                f"Unsloth: {model_name} can only handle sequence lengths of at most "\
-                f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
-                f"{round(rope_scaling, 3)}, it can be magically be extended to "\
+                f"Unsloth: {model_name} can only handle sequence lengths of at most "
+                f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "
+                f"{round(rope_scaling, 3)}, it can be magically be extended to "
                 f"{max_seq_length}!"
             )
 
             # Warn RoPE scaling isn't allowed
             if not has_rope_scaling:
                 raise RuntimeError(
-                    f"However, {model_name} doesn't support RoPE Scaling!\n"\
+                    f"However, {model_name} doesn't support RoPE Scaling!\n"
                     "Please file a feature request at https://github.com/unslothai/unsloth."
                 )
-            pass
 
-            rope_scaling = {"type": "linear", "factor": rope_scaling,}
+            rope_scaling = {
+                "type": "linear",
+                "factor": rope_scaling,
+            }
 
             # Add to kwargs
             kwargs["rope_scaling"] = rope_scaling
-        pass
 
         bnb_config = None
         if load_in_4bit:
-            llm_int8_skip_modules =  SKIP_QUANTIZATION_MODULES.copy()
+            llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy()
             if IS_FALCON_H1:
                 # we cannot quantize out_proj layer due to mamba kernels: https://github.com/tiiuae/Falcon-H1/issues/13#issuecomment-2918671274
                 llm_int8_skip_modules.append("out_proj")
             bnb_config = BitsAndBytesConfig(
-                load_in_4bit              = True,
+                load_in_4bit = True,
                 bnb_4bit_use_double_quant = True,
-                bnb_4bit_quant_type       = "nf4",
-                bnb_4bit_compute_dtype    = dtype,
-                llm_int8_skip_modules     = llm_int8_skip_modules,
+                bnb_4bit_quant_type = "nf4",
+                bnb_4bit_compute_dtype = dtype,
+                llm_int8_skip_modules = llm_int8_skip_modules,
             )
-        pass
 
         # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
         # RoPE Scaling's max_position_embeddings must be updated
         max_position_embeddings = max(max_seq_length, model_max_seq_length)
-        kwargs.pop("attn_implementation", None); # No need since we auto call it
+        kwargs.pop("attn_implementation", None)  # No need since we auto call it
 
         # Cannot be None, since HF now checks for the config
-        if load_in_4bit: kwargs["quantization_config"] = bnb_config
+        if load_in_4bit:
+            kwargs["quantization_config"] = bnb_config
 
         kwargs = add_dtype_kwargs(dtype, kwargs)
 
@@ -2020,26 +2306,26 @@ def from_pretrained(
         if num_labels is not None:
             model = AutoModelForSequenceClassification.from_pretrained(
                 model_name,
-                device_map              = device_map,
+                device_map = device_map,
                 # torch_dtype             = dtype, # transformers changed torch_dtype to dtype
-                num_labels              = num_labels,
-                #quantization_config     = bnb_config,
-                token                   = token,
+                num_labels = num_labels,
+                # quantization_config     = bnb_config,
+                token = token,
                 max_position_embeddings = max_position_embeddings,
-                trust_remote_code       = trust_remote_code,
-                attn_implementation     = "eager",
+                trust_remote_code = trust_remote_code,
+                attn_implementation = "eager",
                 **kwargs,
             )
         elif not fast_inference:
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
-                device_map              = device_map,
+                device_map = device_map,
                 # torch_dtype             = dtype, # transformers changed torch_dtype to dtype
                 # quantization_config     = bnb_config,
-                token                   = token,
+                token = token,
                 max_position_embeddings = max_position_embeddings,
-                trust_remote_code       = trust_remote_code,
-                attn_implementation     = "eager",
+                trust_remote_code = trust_remote_code,
+                attn_implementation = "eager",
                 **kwargs,
             )
             model.fast_generate = model.generate
@@ -2051,35 +2337,38 @@ def from_pretrained(
                 convert_vllm_to_huggingface,
                 generate_batches,
             )
+
             allowed_args = inspect.getfullargspec(load_vllm).args
             load_vllm_kwargs = dict(
-                model_name             = model_name,
-                config                 = model_config,
+                model_name = model_name,
+                config = model_config,
                 gpu_memory_utilization = gpu_memory_utilization,
-                max_seq_length         = max_seq_length,
-                dtype                  = dtype,
-                float8_kv_cache        = float8_kv_cache,
-                enable_lora            = True,
-                max_lora_rank          = max_lora_rank,
-                disable_log_stats      = disable_log_stats,
-                use_bitsandbytes       = load_in_4bit,
-                unsloth_vllm_standby   = unsloth_vllm_standby,
+                max_seq_length = max_seq_length,
+                dtype = dtype,
+                float8_kv_cache = float8_kv_cache,
+                enable_lora = True,
+                max_lora_rank = max_lora_rank,
+                disable_log_stats = disable_log_stats,
+                use_bitsandbytes = load_in_4bit,
+                unsloth_vllm_standby = unsloth_vllm_standby,
             )
             for allowed_arg in allowed_args:
                 if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
                     load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg]
-            pass
 
             # Load vLLM first
             llm = load_vllm(**load_vllm_kwargs)
 
             # Convert to HF format
             _, quant_state_dict = get_vllm_state_dict(llm, config = model_config)
-            model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype, bnb_config)
+            model = convert_vllm_to_huggingface(
+                quant_state_dict, model_config, dtype, bnb_config
+            )
             model.vllm_engine = llm
             model.fast_generate = model.vllm_engine.generate
-            model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine)
-        pass
+            model.fast_generate_batches = functools.partial(
+                generate_batches, model.vllm_engine
+            )
         raise_handler.remove()
         # Return old flag
         os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
@@ -2087,12 +2376,12 @@ def from_pretrained(
         # Counteract saved tokenizers
         tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
         tokenizer = load_correct_tokenizer(
-            tokenizer_name    = tokenizer_name,
-            model_max_length  = max_position_embeddings,
-            padding_side      = "right",
-            token             = token,
+            tokenizer_name = tokenizer_name,
+            model_max_length = max_position_embeddings,
+            padding_side = "right",
+            token = token,
             trust_remote_code = trust_remote_code,
-            fix_tokenizer     = fix_tokenizer,
+            fix_tokenizer = fix_tokenizer,
         )
 
         model, tokenizer = patch_tokenizer(model, tokenizer)
@@ -2101,11 +2390,11 @@ def from_pretrained(
         # Patch up QKV / O and MLP
         for idx, layer in enumerate(model.model.layers):
             layer.self_attn.apply_qkv = original_apply_qkv
-            layer.self_attn.apply_o   = original_apply_o
-        pass
+            layer.self_attn.apply_o = original_apply_o
 
         # Patch Trainer
         from transformers.trainer import Trainer
+
         try:
             if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
                 inner_training_loop = inspect.getsource(Trainer._inner_training_loop)
@@ -2113,22 +2402,29 @@ def from_pretrained(
             else:
                 inner_training_loop = Trainer._original_training_loop
         except:
-            raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop')
-        pass
+            raise RuntimeError("Unsloth: Unsuccessfully patched inner_training_loop")
 
         import transformers.trainer
+
         items_in_trainer = dir(transformers.trainer)
         good_items = []
         for item in items_in_trainer:
-            if item in inner_training_loop: good_items.append(item)
-        pass
-        exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
+            if item in inner_training_loop:
+                good_items.append(item)
+        exec(
+            "from transformers.trainer import ("
+            + ", ".join(x for x in good_items)
+            + ")",
+            globals(),
+        )
 
-        start = re.search(r'logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
+        start = re.search(
+            r"logger\.info\([\"\'].+?Running training", inner_training_loop
+        ).span(0)[0]
         end = inner_training_loop.find("\n\n", start)
         original_debug = inner_training_loop[start:end]
-        spaces = re.search(r'\n([\s\t]{1,})', original_debug).group(0)[1:]
-        front_spaces = re.match(r'([\s\t]{1,})', inner_training_loop).group(0)
+        spaces = re.search(r"\n([\s\t]{1,})", original_debug).group(0)[1:]
+        front_spaces = re.match(r"([\s\t]{1,})", inner_training_loop).group(0)
 
         # Cannot use \\ since it will cause a SyntaxWarning in Python 3.12
         # Instead use chr(92) == \\
@@ -2147,8 +2443,10 @@ def from_pretrained(
             else:
                 torch.cuda.empty_cache()"""
 
-        debug_info = debug_info.split('\n')
-        debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
+        debug_info = debug_info.split("\n")
+        debug_info = "\n".join(
+            [debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]
+        )
         inner_training_loop = inner_training_loop.replace(original_debug, debug_info)
 
         debug_info = """n_total_devices = total_train_batch_size // \\
@@ -2156,19 +2454,24 @@ def from_pretrained(
         if n_total_devices > 1:
             logger.warning_once('Unsloth is running with multi GPUs - the effective batch size is multiplied by ' + str(n_total_devices))
         debug_info ="""
-        debug_info = debug_info.split('\n')
-        debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
+        debug_info = debug_info.split("\n")
+        debug_info = "\n".join(
+            [debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]
+        )
         inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1)
 
         front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
-        inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE)
+        inner_training_loop = re.sub(
+            r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE
+        )
         inner_training_loop = inner_training_loop.replace(
             "train_dataloader = tpu_spmd_dataloader(train_dataloader)",
-            "raise RuntimeError('Unsloth: TPUs are not yet supported!')"
+            "raise RuntimeError('Unsloth: TPUs are not yet supported!')",
         )
         inner_training_loop = inner_training_loop.replace(
             "_inner_training_loop",
-            "_fast_inner_training_loop", 1,
+            "_fast_inner_training_loop",
+            1,
         )
         inner_training_loop = inner_training_loop.replace(
             "is_torch_tpu_available()",
@@ -2183,7 +2486,6 @@ def from_pretrained(
         while hasattr(m, "model"):
             m.max_seq_length = max_seq_length
             m = m.model
-        pass
         m.max_seq_length = max_seq_length
         # Save to modules as well
         for module in model.modules():
@@ -2192,14 +2494,13 @@ def from_pretrained(
         # We check the tokenizer first for errors
         if fix_tokenizer:
             tokenizer = check_tokenizer(
-                model            = model,
-                tokenizer        = tokenizer,
-                model_name       = model_name,
+                model = model,
+                tokenizer = tokenizer,
+                model_name = model_name,
                 model_max_length = max_position_embeddings,
-                padding_side     = "right",
-                token            = token,
+                padding_side = "right",
+                token = token,
             )
-        pass
         patch_saving_functions(tokenizer)
 
         # Fix up config for transformers uploading PEFT
@@ -2207,13 +2508,11 @@ def from_pretrained(
         if False:
             name = model.config._name_or_path
             if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
-                name = name[:len(name) - len("-bnb-4bit")]
-                model.config.update({"_name_or_path" : name})
-            pass
-        pass
+                name = name[: len(name) - len("-bnb-4bit")]
+                model.config.update({"_name_or_path": name})
 
         # Log Unsloth version for future fastpaths for inference
-        model.config.update({"unsloth_version" : __version__})
+        model.config.update({"unsloth_version": __version__})
 
         # Add save modules
         patch_saving_functions(model)
@@ -2223,7 +2522,7 @@ def from_pretrained(
         patch_gradient_accumulation_fix(Trainer)
 
         # Save tokenizer for inference purposes
-        tokenizer.padding_side = "left" # Force inference
+        tokenizer.padding_side = "left"  # Force inference
         internal_model = model
         while hasattr(internal_model, "model"):
             internal_model._saved_temp_tokenizer = tokenizer
@@ -2231,7 +2530,6 @@ def from_pretrained(
             internal_model.is_loaded_in_8bit = True
 
             internal_model = internal_model.model
-        pass
         internal_model._saved_temp_tokenizer = tokenizer
         # Also set is_loaded_in_8bit to disable incorrect DDP
         internal_model.is_loaded_in_8bit = True
@@ -2241,132 +2539,149 @@ def from_pretrained(
             rotary_emb = model.model.rotary_emb
             for layer in model.model.layers:
                 layer.self_attn.rotary_emb = rotary_emb
-        pass
 
         # Add for_inference and for_training
-        model.for_training  = functools.partial(FastLlamaModel.for_training,  model)
+        model.for_training = functools.partial(FastLlamaModel.for_training, model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
         m = model
         while hasattr(m, "model"):
-            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_training = functools.partial(FastBaseModel.for_training, m)
             m.for_inference = functools.partial(FastBaseModel.for_inference, m)
             m = m.model
 
         # Patch generate
-        is_classification =  "Classification" in str(type(model))
+        is_classification = "Classification" in str(type(model))
         if not is_classification and model.generate.__name__ != "unsloth_fast_generate":
             model._old_generate = model.generate
             unsloth_fast_generate.__doc__ = model._old_generate.__doc__
             model.generate = types.MethodType(unsloth_fast_generate, model)
-        pass
         # Set weight[padding_idx] = 0
         with torch.no_grad():
             for name, module in model.named_modules():
                 if type(module) is torch.nn.Embedding:
-                    if getattr(module, "weight", None) is not None and getattr(module, "padding_idx", None) is not None:
+                    if (
+                        getattr(module, "weight", None) is not None
+                        and getattr(module, "padding_idx", None) is not None
+                    ):
                         if module.padding_idx < module.weight.shape[0]:
                             module.weight[module.padding_idx] = 0
         return model, tokenizer
-    pass
-
 
     @staticmethod
     def post_patch(model, tokenizer):
-        model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = True)
+        model, tokenizer = patch_model_and_tokenizer(
+            model, tokenizer, downcast_rope = True
+        )
         return model, tokenizer
-    pass
-
 
     @staticmethod
     def get_peft_model(
         model,
-        r                   = 16,
-        target_modules      = ["q_proj", "k_proj", "v_proj", "o_proj",
-                               "gate_proj", "up_proj", "down_proj"],
-        lora_alpha          = 16,
-        lora_dropout        = 0.0,
-        bias                = "none",
+        r = 16,
+        target_modules = [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+        lora_alpha = 16,
+        lora_dropout = 0.0,
+        bias = "none",
         layers_to_transform = None,
-        layers_pattern      = None,
+        layers_pattern = None,
         use_gradient_checkpointing = "unsloth",
-        random_state        = 3407,
-        max_seq_length      = 2048, # not used anymore
-        use_rslora          = False,
-        modules_to_save     = None,
-        init_lora_weights   = True,
-        loftq_config        = {},
-        temporary_location  = "_unsloth_temporary_saved_buffers",
-        qat_scheme          = None,
+        random_state = 3407,
+        max_seq_length = 2048,  # not used anymore
+        use_rslora = False,
+        modules_to_save = None,
+        init_lora_weights = True,
+        loftq_config = {},
+        temporary_location = "_unsloth_temporary_saved_buffers",
+        qat_scheme = None,
         **kwargs,
     ):
         if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
             # Check for other PEFT args in kwargs
-            for (peft_arg, flag) in (
+            for peft_arg, flag in (
                 ("finetune_vision_layers", False),
                 ("finetune_language_layers", True),
                 ("finetune_attention_modules", True),
                 ("finetune_mlp_modules", True),
             ):
-                if peft_arg not in kwargs: kwargs[peft_arg] = flag
+                if peft_arg not in kwargs:
+                    kwargs[peft_arg] = flag
             return FastBaseModel.get_peft_model(
-                model                      = model,
-                r                          = r,
-                target_modules             = target_modules,
-                lora_alpha                 = lora_alpha,
-                lora_dropout               = lora_dropout,
-                bias                       = bias,
-                layers_to_transform        = layers_to_transform,
-                layers_pattern             = layers_pattern,
+                model = model,
+                r = r,
+                target_modules = target_modules,
+                lora_alpha = lora_alpha,
+                lora_dropout = lora_dropout,
+                bias = bias,
+                layers_to_transform = layers_to_transform,
+                layers_pattern = layers_pattern,
                 use_gradient_checkpointing = use_gradient_checkpointing,
-                random_state               = random_state,
-                max_seq_length             = max_seq_length,
-                use_rslora                 = use_rslora,
-                modules_to_save            = modules_to_save,
-                init_lora_weights          = init_lora_weights,
-                loftq_config               = loftq_config,
-                temporary_location         = temporary_location,
+                random_state = random_state,
+                max_seq_length = max_seq_length,
+                use_rslora = use_rslora,
+                modules_to_save = modules_to_save,
+                init_lora_weights = init_lora_weights,
+                loftq_config = loftq_config,
+                temporary_location = temporary_location,
                 **kwargs,
             )
-        pass
         if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
-            print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect")
+            print(
+                "Unsloth: Full finetuning is enabled, so .get_peft_model has no effect"
+            )
             return model
-        pass
         transformers_set_seed(random_state)
 
         if use_gradient_checkpointing == "unsloth":
-            patch_unsloth_smart_gradient_checkpointing(dtype = model.get_input_embeddings().weight.dtype)
+            patch_unsloth_smart_gradient_checkpointing(
+                dtype = model.get_input_embeddings().weight.dtype
+            )
 
         if type(r) is not int:
             raise TypeError(f"Unsloth: Rank of {str(r)} must be an integer.")
         if r <= 0:
             raise TypeError(f"Unsloth: Rank of {str(r)} must be larger than 0.")
 
-        if isinstance(model, PeftModelForCausalLM) or isinstance(model, PeftModelForSequenceClassification):
+        if isinstance(model, PeftModelForCausalLM) or isinstance(
+            model, PeftModelForSequenceClassification
+        ):
             # Check if exactly the same and then pass through!
-            assert(hasattr(model, "peft_config"))
+            assert hasattr(model, "peft_config")
 
             peft_config = model.peft_config["default"].to_dict()
             check_parameters = [
-                "r", "lora_alpha", "lora_dropout",
-                "bias", "layers_to_transform", "layers_pattern",
-                "use_rslora", "init_lora_weights",
+                "r",
+                "lora_alpha",
+                "lora_dropout",
+                "bias",
+                "layers_to_transform",
+                "layers_pattern",
+                "use_rslora",
+                "init_lora_weights",
             ]
             check_all = True
             for param in check_parameters:
                 check_all = check_all and (peft_config[param] == eval(param))
-            pass
 
             # Check save_modules
             old_target_modules = list(peft_config["target_modules"])
             modules_to_save = peft_config["modules_to_save"]
-            if modules_to_save is None: modules_to_save = {}
+            if modules_to_save is None:
+                modules_to_save = {}
             modules_to_save = list(modules_to_save)
             old_target_modules += modules_to_save
 
             # Combine all
-            new_target_modules = list(target_modules) + \
-                list(modules_to_save if modules_to_save is not None else [])
+            new_target_modules = list(target_modules) + list(
+                modules_to_save if modules_to_save is not None else []
+            )
 
             # Now check!
             new_target_modules = set(new_target_modules)
@@ -2375,8 +2690,11 @@ def get_peft_model(
             )
 
             check_all = check_all and (
-                (loftq_config == {} or loftq_config is None) and \
-                (peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None)
+                (loftq_config == {} or loftq_config is None)
+                and (
+                    peft_config["loftq_config"] == {}
+                    or peft_config["loftq_config"] is None
+                )
             )
 
             if check_all:
@@ -2388,24 +2706,28 @@ def get_peft_model(
                 # Offload!
                 # [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
                 if "embed_tokens" in new_target_modules:
-                    print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
+                    print(
+                        "Unsloth: Training embed_tokens in mixed precision to save VRAM"
+                    )
 
                     new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
                     if new_dtype == torch.float16:
                         # See https://github.com/unslothai/unsloth/pull/1200
                         # Tesla T4 must use float32 and not float16
                         new_dtype = torch.float32
-                    pass
 
-                    model.get_input_embeddings().modules_to_save.default\
-                        .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
-                    model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
+                    model.get_input_embeddings().modules_to_save.default.to(
+                        device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
+                    )
+                    model.get_input_embeddings().modules_to_save.default.requires_grad_(
+                        True
+                    )
 
                     # [TODO] Move old embed_tokens to CPU - should be disk!
-                    model.get_input_embeddings().original_module\
-                        .to(device = "cpu", non_blocking = True)
+                    model.get_input_embeddings().original_module.to(
+                        device = "cpu", non_blocking = True
+                    )
                     model.get_input_embeddings().original_module.requires_grad_(False)
-                pass
 
                 if "lm_head" in new_target_modules:
                     print("Unsloth: Training lm_head in mixed precision to save VRAM")
@@ -2415,101 +2737,106 @@ def get_peft_model(
                         # See https://github.com/unslothai/unsloth/pull/1200
                         # Tesla T4 must use float32 and not float16
                         new_dtype = torch.float32
-                    pass
 
-                    model.get_output_embeddings().modules_to_save.default\
-                        .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
-                    model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
+                    model.get_output_embeddings().modules_to_save.default.to(
+                        device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
+                    )
+                    model.get_output_embeddings().modules_to_save.default.requires_grad_(
+                        True
+                    )
 
                     # [TODO] Move old lm_head to CPU - should be disk!
-                    model.get_output_embeddings().original_module\
-                        .to(device = "cpu", non_blocking = True)
+                    model.get_output_embeddings().original_module.to(
+                        device = "cpu", non_blocking = True
+                    )
                     model.get_output_embeddings().original_module.requires_grad_(False)
-                pass
 
                 return model
             else:
                 raise TypeError(
                     "Unsloth: Your model already has LoRA adapters. Your new parameters are different."
                 )
-            pass
-        pass
 
-        if loftq_config is None: loftq_config = {}
+        if loftq_config is None:
+            loftq_config = {}
 
         signature = str(inspect.signature(LoraConfig))
-        SUPPORTS_LOFTQ  = "loftq_config" in signature
-        SUPPORTS_RSLORA = "use_rslora"   in signature
+        SUPPORTS_LOFTQ = "loftq_config" in signature
+        SUPPORTS_RSLORA = "use_rslora" in signature
 
         if lora_dropout != 0:
             logger.warning_once(
-                f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
+                f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"
                 f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
             )
-        pass
 
         if bias != "none":
             logger.warning_once(
-                f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
+                f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"
                 f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
             )
-        pass
 
-        if not (type(init_lora_weights) is bool or \
-            init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
+        if not (
+            type(init_lora_weights) is bool
+            or init_lora_weights == "gaussian"
+            or init_lora_weights == "loftq"
+        ):
             raise ValueError(
                 'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
             )
-        pass
 
         if init_lora_weights == "loftq":
-
             if not SUPPORTS_LOFTQ:
                 import peft
+
                 raise RuntimeError(
-                    f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
-                    "Please install PEFT 0.7.2 or higher.\n"\
+                    f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"
+                    "Please install PEFT 0.7.2 or higher.\n"
                     "You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
                 )
-            pass
 
             if loftq_config == {}:
                 from peft import LoftQConfig
+
                 logger.warning_once(
-                    "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
+                    "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
                     "We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
                 )
                 loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
-            pass
 
             if hasattr(model.config, "quantization_config"):
                 raise ValueError(
-                    "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
+                    "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"
                     "Reload your model without any quantization by setting `load_in_4bit = False`."
                 )
-            pass
-        pass
 
-        assert(type(use_rslora) is bool)
+        assert type(use_rslora) is bool
         if use_rslora:
             if not SUPPORTS_RSLORA:
                 # We manually check for PEFT
                 import peft
+
                 raise RuntimeError(
-                    f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"\
-                    "Please install PEFT 0.7.2 or higher.\n"\
+                    f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"
+                    "Please install PEFT 0.7.2 or higher.\n"
                     "You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
                 )
-            pass
-        pass
 
-        accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
-                                      "gate_proj", "up_proj", "down_proj",),)
-        model.config.update({"unsloth_version" : __version__})
+        accepted_modules = frozenset(
+            (
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ),
+        )
+        model.config.update({"unsloth_version": __version__})
 
         if type(modules_to_save) is tuple:
             modules_to_save = list(modules_to_save)
-        pass
 
         train_lm_head = False
         train_embed_tokens = False
@@ -2521,8 +2848,10 @@ def get_peft_model(
                 #     "Luckily, we shall do it for you!"
                 # )
                 train_lm_head = True
-                if modules_to_save is None: modules_to_save = ["lm_head"]
-                else: modules_to_save.append("lm_head")
+                if modules_to_save is None:
+                    modules_to_save = ["lm_head"]
+                else:
+                    modules_to_save.append("lm_head")
 
             elif module == "embed_tokens":
                 # logger.warning_once(
@@ -2530,40 +2859,41 @@ def get_peft_model(
                 #     "Luckily, we shall do it for you!"
                 # )
                 train_embed_tokens = True
-                if modules_to_save is None: modules_to_save = ["embed_tokens"]
-                else: modules_to_save.append("embed_tokens")
+                if modules_to_save is None:
+                    modules_to_save = ["embed_tokens"]
+                else:
+                    modules_to_save.append("embed_tokens")
 
             else:
                 try:
-                    assert(module in accepted_modules)
+                    assert module in accepted_modules
                     final_modules.append(module)
                 except AssertionError as e:
                     final_modules.append(module)
                     print(
-                        "Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"\
+                        "Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"
                         "Beware - your finetuning might be noticeably slower!"
                     )
-                pass
-            pass
-        pass
 
         # Check if we added new tokens!
         if hasattr(model, "_need_to_train_embeddings"):
             if not train_lm_head or not train_embed_tokens:
                 print(
-                    "Unsloth: You added new tokens but did not specify if you wanted to "\
+                    "Unsloth: You added new tokens but did not specify if you wanted to "
                     "train the lm_head and embed_tokens.\nWe must turn it on for you."
                 )
                 train_lm_head = True
                 train_embed_tokens = True
 
-                if modules_to_save is None: modules_to_save = ["embed_tokens"]
-                else: modules_to_save.append("embed_tokens")
+                if modules_to_save is None:
+                    modules_to_save = ["embed_tokens"]
+                else:
+                    modules_to_save.append("embed_tokens")
 
-                if modules_to_save is None: modules_to_save = ["lm_head"]
-                else: modules_to_save.append("lm_head")
-            pass
-        pass
+                if modules_to_save is None:
+                    modules_to_save = ["lm_head"]
+                else:
+                    modules_to_save.append("lm_head")
 
         # Check for Llama-3
         # if hasattr(model._saved_temp_tokenizer, "_using_llama3_template"):
@@ -2587,11 +2917,8 @@ def get_peft_model(
                     raise TypeError(
                         f"Unsloth: Module = {module} is not allowed. Only 'lm_head' and 'embed_tokens' is allowed."
                     )
-            pass
-        pass
         if isinstance(modules_to_save, (tuple, list)):
             modules_to_save = list(set(modules_to_save))
-        pass
 
         vllm_engine = None
         if hasattr(model, "vllm_engine"):
@@ -2601,39 +2928,44 @@ def get_peft_model(
             vllm_fast_generate_batches = model.fast_generate_batches
 
             if modules_to_save is not None:
-                raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.")
+                raise NotImplementedError(
+                    "Unsloth: Currently fast inference does not work with training embeddings or lm_head."
+                )
 
             if bias != "none":
-                raise NotImplementedError("Unsloth: Currently fast inference does not work with using biases for LoRA.")
-        pass
+                raise NotImplementedError(
+                    "Unsloth: Currently fast inference does not work with using biases for LoRA."
+                )
 
         # Does not get lora yet, so get name from model, not base model
         is_classification = "Classification" in str(type(model))
 
         arguments = dict(
-            r                   = r,
-            lora_alpha          = lora_alpha,
-            target_modules      = final_modules,
-            lora_dropout        = lora_dropout,
-            bias                = bias,
-            task_type           = TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS,
+            r = r,
+            lora_alpha = lora_alpha,
+            target_modules = final_modules,
+            lora_dropout = lora_dropout,
+            bias = bias,
+            task_type = TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS,
             layers_to_transform = layers_to_transform,
-            init_lora_weights   = init_lora_weights,
-            loftq_config        = loftq_config,
-            use_rslora          = use_rslora,
-            modules_to_save     = modules_to_save,
+            init_lora_weights = init_lora_weights,
+            loftq_config = loftq_config,
+            use_rslora = use_rslora,
+            modules_to_save = modules_to_save,
             **kwargs,
         )
-        if not SUPPORTS_LOFTQ:  del arguments["loftq_config"]
-        if not SUPPORTS_RSLORA: del arguments["use_rslora"]
+        if not SUPPORTS_LOFTQ:
+            del arguments["loftq_config"]
+        if not SUPPORTS_RSLORA:
+            del arguments["use_rslora"]
 
         _saved_temp_tokenizer = model._saved_temp_tokenizer
 
         lora_config = LoraConfig(**arguments)
         # First offload lm_head and embed_tokens to disk
-        input_embeddings_device  = model.get_input_embeddings().weight.device
+        input_embeddings_device = model.get_input_embeddings().weight.device
         if is_classification:
-             output_embeddings_device = model.score.weight.device
+            output_embeddings_device = model.score.weight.device
         else:
             output_embeddings_device = model.get_output_embeddings().weight.device
 
@@ -2641,25 +2973,20 @@ def get_peft_model(
             if train_embed_tokens:
                 print("Unsloth: Offloading input_embeddings to disk to save VRAM")
                 offload_input_embeddings(model, temporary_location)
-            pass
 
             # Remove old items to save VRAM
             for _ in range(3):
                 gc.collect()
                 clean_gpu_cache()
-            pass
 
             if train_lm_head:
                 print("Unsloth: Offloading output_embeddings to disk to save VRAM")
                 offload_output_embeddings(model, temporary_location)
-            pass
 
             # Remove old items to save VRAM
             for _ in range(3):
                 gc.collect()
                 clean_gpu_cache()
-            pass
-        pass
 
         model = _get_peft_model(model, lora_config)
         # Fix LoraConfig.auto_mapping is None
@@ -2669,7 +2996,6 @@ def get_peft_model(
         if qat_scheme is not None:
             print("Unsloth: Applying QAT to mitigate quantization degradation")
             model = _prepare_model_for_qat(model, qat_scheme)
-        pass
 
         model._saved_temp_tokenizer = _saved_temp_tokenizer
 
@@ -2677,49 +3003,48 @@ def get_peft_model(
 
         if train_embed_tokens:
             print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
-            assert(hasattr(model.get_input_embeddings(), "modules_to_save"))
+            assert hasattr(model.get_input_embeddings(), "modules_to_save")
 
-            new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
+            new_dtype = (
+                model.get_input_embeddings().modules_to_save.default.weight.dtype
+            )
             if new_dtype == torch.float16:
                 # See https://github.com/unslothai/unsloth/pull/1200
                 # Tesla T4 must use float32 and not float16
                 new_dtype = torch.float32
-            pass
 
-            model.get_input_embeddings().modules_to_save.default\
-                .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
+            model.get_input_embeddings().modules_to_save.default.to(
+                device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
+            )
             model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
-        pass
 
         if train_lm_head:
             print("Unsloth: Training lm_head in mixed precision to save VRAM")
-            assert(hasattr(model.get_output_embeddings(), "modules_to_save"))
+            assert hasattr(model.get_output_embeddings(), "modules_to_save")
 
-            new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
+            new_dtype = (
+                model.get_output_embeddings().modules_to_save.default.weight.dtype
+            )
             if new_dtype == torch.float16:
                 # See https://github.com/unslothai/unsloth/pull/1200
                 # Tesla T4 must use float32 and not float16
                 new_dtype = torch.float32
-            pass
 
-            model.get_output_embeddings().modules_to_save.default\
-                .to(device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True)
+            model.get_output_embeddings().modules_to_save.default.to(
+                device = DEVICE_TYPE_TORCH, dtype = new_dtype, non_blocking = True
+            )
             model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
-        pass
 
         # Patch tokenizer to pad to the right
         internal_model = model
         while hasattr(internal_model, "model"):
             if hasattr(internal_model, "_saved_temp_tokenizer"):
                 internal_model._saved_temp_tokenizer.padding_side = "right"
-            pass
             # Also set is_loaded_in_8bit to disable incorrect DDP
             internal_model.is_loaded_in_8bit = True
             internal_model = internal_model.model
-        pass
         if hasattr(internal_model, "_saved_temp_tokenizer"):
             internal_model._saved_temp_tokenizer.padding_side = "right"
-        pass
         # Also set is_loaded_in_8bit to disable incorrect DDP
         internal_model.is_loaded_in_8bit = True
 
@@ -2727,21 +3052,18 @@ def get_peft_model(
         for _ in range(3):
             gc.collect()
             clean_gpu_cache()
-        pass
 
         patch_peft_fast_inference(model)
 
         # Add for_inference and for_training
-        model.for_training  = functools.partial(FastLlamaModel.for_training,  model)
+        model.for_training = functools.partial(FastLlamaModel.for_training, model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
         m = model
         while hasattr(m, "model"):
-            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_training = functools.partial(FastBaseModel.for_training, m)
             m.for_inference = functools.partial(FastBaseModel.for_inference, m)
             m = m.model
         return model
-    pass
-
 
     @staticmethod
     def patch_peft_model(
@@ -2753,29 +3075,38 @@ def patch_peft_model(
                 model = model,
                 use_gradient_checkpointing = use_gradient_checkpointing,
             )
-        pass
-        if not isinstance(model, PeftModelForCausalLM) and not isinstance(model, PeftModelForSequenceClassification):
+        if not isinstance(model, PeftModelForCausalLM) and not isinstance(
+            model, PeftModelForSequenceClassification
+        ):
             raise TypeError(
                 "Unsloth: Your model needs to call `.get_peft_model` first!"
             )
-        pass
 
         # Get activation function
         model_type = model.config.model_type
 
-        if   model_type == "llama":     apply_lora_mlp = apply_lora_mlp_swiglu
-        elif model_type == "mistral":   apply_lora_mlp = apply_lora_mlp_swiglu
-        elif model_type == "qwen2":     apply_lora_mlp = apply_lora_mlp_swiglu
-        elif model_type == "gemma":     apply_lora_mlp = apply_lora_mlp_geglu_approx
-        elif model_type == "gemma2":    apply_lora_mlp = apply_lora_mlp_geglu_approx
-        elif model_type == "cohere":    apply_lora_mlp = apply_lora_mlp_swiglu
-        elif model_type == "granite":   apply_lora_mlp = apply_lora_mlp_swiglu
-        elif model_type == "qwen3":     apply_lora_mlp = apply_lora_mlp_swiglu
-        elif model_type == "falcon_h1": apply_lora_mlp = apply_lora_mlp_swiglu
-        elif model_type == "qwen3moe":  apply_lora_mlp = apply_lora_mlp_swiglu
+        if model_type == "llama":
+            apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "mistral":
+            apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "qwen2":
+            apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "gemma":
+            apply_lora_mlp = apply_lora_mlp_geglu_approx
+        elif model_type == "gemma2":
+            apply_lora_mlp = apply_lora_mlp_geglu_approx
+        elif model_type == "cohere":
+            apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "granite":
+            apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "qwen3":
+            apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "falcon_h1":
+            apply_lora_mlp = apply_lora_mlp_swiglu
+        elif model_type == "qwen3moe":
+            apply_lora_mlp = apply_lora_mlp_swiglu
         else:
             raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
-        pass
 
         model = prepare_model_for_kbit_training(
             model,
@@ -2789,53 +3120,56 @@ def patch_peft_model(
             if False:
                 name = model.peft_config[active_adapter].base_model_name_or_path
                 if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
-                    name = name[:len(name) - len("-bnb-4bit")]
+                    name = name[: len(name) - len("-bnb-4bit")]
                     model.peft_config[active_adapter].base_model_name_or_path = name
-                pass
             # Add revision to enable future fast inference paths
             # [TODO] Bugs out!see https://github.com/unslothai/unsloth/issues/492
             # model.peft_config[active_adapter].revision = f"unsloth"
-        pass
 
         from transformers.trainer import Trainer
+
         if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
-            raise RuntimeError("Unsloth: Unsuccessfully patched Trainer! Please file a bug report!")
-        pass
+            raise RuntimeError(
+                "Unsloth: Unsuccessfully patched Trainer! Please file a bug report!"
+            )
 
         # Fix loftq issues
         # loftq_config must not = None, but rather {}
         all_configs = model.peft_config
         for key, current_config in all_configs.items():
-            if hasattr(current_config, "loftq_config") and current_config.loftq_config is None:
+            if (
+                hasattr(current_config, "loftq_config")
+                and current_config.loftq_config is None
+            ):
                 new_args = current_config.__dict__
                 new_args["loftq_config"] = {}
                 current_config = current_config.__class__(**new_args)
                 all_configs[key] = current_config
-            pass
-        pass
 
         # Do patching
         n_mlp = 0
         n_qkv = 0
-        n_o   = 0
+        n_o = 0
 
-        active_adapter = model.active_adapters[0] if \
-            hasattr(model, "active_adapters") else model.active_adapter
+        active_adapter = (
+            model.active_adapters[0]
+            if hasattr(model, "active_adapters")
+            else model.active_adapter
+        )
 
         # Get dropout and bias
         lora_dropout = model.peft_config[active_adapter].lora_dropout
-        bias         = model.peft_config[active_adapter].bias
+        bias = model.peft_config[active_adapter].bias
 
         # We also do not inplace edit QKV for Cohere!
-        _apply_lora_mlp = \
-            functools.partial(apply_lora_mlp, inplace = False) \
-            if model_type == "cohere" else \
-            apply_lora_mlp
-        pass
+        _apply_lora_mlp = (
+            functools.partial(apply_lora_mlp, inplace = False)
+            if model_type == "cohere"
+            else apply_lora_mlp
+        )
 
         if lora_dropout == 0 and bias == "none":
             for idx, layer in enumerate(model.model.model.layers):
-
                 if model_type != "falcon_h1":
                     # LoRAMLP.apply doesn't have functionality for gate and down multipliers yet.
                     # Don't patch falcon h1 for the time being.
@@ -2843,75 +3177,83 @@ def patch_peft_model(
                     # MLP patching
                     mlp_module = layer.mlp
                     gate_proj = mlp_module.gate_proj
-                    up_proj   = mlp_module.  up_proj
+                    up_proj = mlp_module.up_proj
                     down_proj = mlp_module.down_proj
 
-                    if hasattr(gate_proj, "lora_A") and \
-                        hasattr(  up_proj, "lora_A") and \
-                        hasattr(down_proj, "lora_A") and \
-                        (getattr(gate_proj, "base_layer", gate_proj).bias is None) and \
-                        (getattr(  up_proj, "base_layer",   up_proj).bias is None) and \
-                        (getattr(down_proj, "base_layer", down_proj).bias is None) and \
-                        (len(getattr(gate_proj, "lora_magnitude_vector", []) or []) == 0) and \
-                        (len(getattr(  up_proj, "lora_magnitude_vector", []) or []) == 0) and \
-                        (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0):
-
+                    if (
+                        hasattr(gate_proj, "lora_A")
+                        and hasattr(up_proj, "lora_A")
+                        and hasattr(down_proj, "lora_A")
+                        and (getattr(gate_proj, "base_layer", gate_proj).bias is None)
+                        and (getattr(up_proj, "base_layer", up_proj).bias is None)
+                        and (getattr(down_proj, "base_layer", down_proj).bias is None)
+                        and (
+                            len(getattr(gate_proj, "lora_magnitude_vector", []) or [])
+                            == 0
+                        )
+                        and (
+                            len(getattr(up_proj, "lora_magnitude_vector", []) or [])
+                            == 0
+                        )
+                        and (
+                            len(getattr(down_proj, "lora_magnitude_vector", []) or [])
+                            == 0
+                        )
+                    ):
                         # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
-                        mlp_module.forward = types.MethodType(_apply_lora_mlp, mlp_module)
+                        mlp_module.forward = types.MethodType(
+                            _apply_lora_mlp, mlp_module
+                        )
                         n_mlp += 1
                     else:
                         logger.warning_once(
-                            "Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"\
+                            "Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"
                             "are not enabled or a bias term (like in Qwen) is used."
                         )
-                    pass
-                pass
 
                 # QKV attention patching
                 q_proj = layer.self_attn.q_proj
                 k_proj = layer.self_attn.k_proj
                 v_proj = layer.self_attn.v_proj
-                if  hasattr(q_proj, "lora_A") and \
-                    hasattr(k_proj, "lora_A") and \
-                    hasattr(v_proj, "lora_A") and \
-                    (getattr(q_proj, "base_layer", q_proj).bias is None) and \
-                    (getattr(k_proj, "base_layer", k_proj).bias is None) and \
-                    (getattr(v_proj, "base_layer", v_proj).bias is None) and \
-                    (len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0) and \
-                    (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \
-                    (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0):
-
+                if (
+                    hasattr(q_proj, "lora_A")
+                    and hasattr(k_proj, "lora_A")
+                    and hasattr(v_proj, "lora_A")
+                    and (getattr(q_proj, "base_layer", q_proj).bias is None)
+                    and (getattr(k_proj, "base_layer", k_proj).bias is None)
+                    and (getattr(v_proj, "base_layer", v_proj).bias is None)
+                    and (len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0)
+                    and (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0)
+                    and (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0)
+                ):
                     layer.self_attn.apply_qkv = apply_lora_qkv
                     n_qkv += 1
                 else:
-                    if model_type == "qwen2": n_qkv += 1
+                    if model_type == "qwen2":
+                        n_qkv += 1
                     else:
                         logger.warning_once(
-                            "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
+                            "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"
                             "are not enabled or a bias term (like in Qwen) is used."
                         )
-                    pass
-                pass
 
                 # O attention patching
                 o_proj = layer.self_attn.o_proj
-                if hasattr(o_proj, "lora_A") and \
-                    (getattr(o_proj, "base_layer", o_proj).bias is None) and \
-                    (len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0):
-
+                if (
+                    hasattr(o_proj, "lora_A")
+                    and (getattr(o_proj, "base_layer", o_proj).bias is None)
+                    and (len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0)
+                ):
                     layer.self_attn.apply_o = apply_lora_o
                     n_o += 1
                 else:
                     logger.warning_once(
-                        "Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"\
+                        "Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"
                         "are not enabled or a bias term (like in Qwen) is used."
                     )
-                pass
-            pass
-        pass
 
         logger.warning_once(
-            f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\
+            f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "
             f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.",
         )
         patch_saving_functions(model)
@@ -2925,7 +3267,6 @@ def patch_peft_model(
         while hasattr(internal_model, "model"):
             internal_model.max_seq_length = max_seq_length
             internal_model = internal_model.model
-        pass
         internal_model.max_seq_length = max_seq_length
         # Save to modules as well
         for module in model.modules():
@@ -2936,116 +3277,116 @@ def patch_peft_model(
         while hasattr(internal_model, "model"):
             if hasattr(internal_model, "_saved_temp_tokenizer"):
                 internal_model._saved_temp_tokenizer.padding_side = "right"
-            pass
             internal_model = internal_model.model
-        pass
         if hasattr(internal_model, "_saved_temp_tokenizer"):
             internal_model._saved_temp_tokenizer.padding_side = "right"
-        pass
 
         # Clear deleted GPU items
         for _ in range(3):
             gc.collect()
             clean_gpu_cache()
-        pass
 
         patch_peft_fast_inference(model)
 
         # Add for_inference and for_training
-        model.for_training  = functools.partial(FastLlamaModel.for_training,  model)
+        model.for_training = functools.partial(FastLlamaModel.for_training, model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
         m = model
         while hasattr(m, "model"):
-            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_training = functools.partial(FastBaseModel.for_training, m)
             m.for_inference = functools.partial(FastBaseModel.for_inference, m)
             m = m.model
         return model
-    pass
-
 
     @staticmethod
     def for_inference(model):
         if not hasattr(model, "parameters"):
-            raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!")
+            raise TypeError(
+                "Unsloth: I think you're passing a tokenizer, not the model to for_inference!"
+            )
 
         def _for_inference(m):
-            if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False
-            if hasattr(m, "training"): m.training = False
+            if hasattr(m, "gradient_checkpointing"):
+                m.gradient_checkpointing = False
+            if hasattr(m, "training"):
+                m.training = False
             # Pad tokenizer to the left
-            if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "left"
+            if hasattr(m, "_saved_temp_tokenizer"):
+                m._saved_temp_tokenizer.padding_side = "left"
             # Set a flag for generation!
             m._flag_for_generation = True
-        pass
+
         m = model
         while hasattr(m, "model"):
             _for_inference(m)
             m = m.model
         _for_inference(m)
-        model.eval() # to turn off training on modules deeper in
+        model.eval()  # to turn off training on modules deeper in
 
         # Since transformers 4.53, must turn off explicitly
         for module in model.modules():
             if hasattr(module, "gradient_checkpointing"):
                 module.gradient_checkpointing = False
-        pass
 
         # Also disable training for embeddings for NEFTune
         if hasattr(model, "get_input_embeddings"):
             embeddings = model.get_input_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = False
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = False
         if hasattr(model, "get_output_embeddings"):
             embeddings = model.get_output_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = False
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = False
         return model
-    pass
-
 
     @staticmethod
     def for_training(model, use_gradient_checkpointing = True):
         if not hasattr(model, "parameters"):
-            raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_training!")
+            raise TypeError(
+                "Unsloth: I think you're passing a tokenizer, not the model to for_training!"
+            )
 
         # Delete all fast inference loras
         for param in model.parameters():
             if hasattr(param, "_fast_lora"):
                 del param._fast_lora
-        pass
 
         def _for_training(m):
-            if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = use_gradient_checkpointing
-            if hasattr(m, "training"): m.training = True
+            if hasattr(m, "gradient_checkpointing"):
+                m.gradient_checkpointing = use_gradient_checkpointing
+            if hasattr(m, "training"):
+                m.training = True
             # Pad tokenizer to the left
-            if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "right"
+            if hasattr(m, "_saved_temp_tokenizer"):
+                m._saved_temp_tokenizer.padding_side = "right"
             # Set a flag for generation!
-            if hasattr(m, "_flag_for_generation"): del m._flag_for_generation
-        pass
+            if hasattr(m, "_flag_for_generation"):
+                del m._flag_for_generation
+
         m = model
         while hasattr(m, "model"):
             _for_training(m)
             m = m.model
         _for_training(m)
-        model.train() # to turn on training on modules deeper in
+        model.train()  # to turn on training on modules deeper in
 
         # Since transformers 4.53, must turn on explicitly
         for module in model.modules():
             if hasattr(module, "gradient_checkpointing"):
                 module.gradient_checkpointing = use_gradient_checkpointing
-        pass
 
         # Also re-enable training for embeddings for NEFTune
         if hasattr(model, "get_input_embeddings"):
             embeddings = model.get_input_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = True
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = True
         if hasattr(model, "get_output_embeddings"):
             embeddings = model.get_output_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = True
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = True
         return model
-    pass
-pass
+
 
 from .rl import PatchFastRL
+
 PatchFastRL(FastLanguageModel = FastLlamaModel)
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index 5f92999f4..04d597280 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -22,17 +22,18 @@
     get_transformers_model_type,
 )
 from .granite import FastGraniteModel
-from .llama   import FastLlamaModel, logger
+from .llama import FastLlamaModel, logger
 from .mistral import FastMistralModel
-from .qwen2   import FastQwen2Model
-from .qwen3   import FastQwen3Model
+from .qwen2 import FastQwen2Model
+from .qwen3 import FastQwen3Model
 from .qwen3_moe import FastQwen3MoeModel
-from .cohere  import FastCohereModel
+from .cohere import FastCohereModel
 from transformers import AutoConfig
 from transformers import __version__ as transformers_version
 from peft import PeftConfig, PeftModel
 from .loader_utils import get_model_name
 import os, contextlib, sys
+
 try:
     from huggingface_hub import get_token
 except:
@@ -41,8 +42,6 @@
     except:
         # For older versions of huggingface_hub
         from huggingface_hub.utils._token import get_token
-    pass
-pass
 from huggingface_hub import HfFileSystem
 import importlib.util
 from ..device_type import (
@@ -58,26 +57,25 @@
 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from unsloth_zoo.utils import Version, _get_dtype
 from unsloth_zoo.hf_utils import dtype_from_config
+
 transformers_version = Version(transformers_version)
-SUPPORTS_FOURBIT   = transformers_version >= Version("4.37")
-SUPPORTS_GEMMA     = transformers_version >= Version("4.38")
-SUPPORTS_GEMMA2    = transformers_version >= Version("4.42")
-SUPPORTS_LLAMA31   = transformers_version >= Version("4.43.2")
-SUPPORTS_LLAMA32   = transformers_version  > Version("4.45.0")
-SUPPORTS_GRANITE   = transformers_version >= Version("4.46.0")
-SUPPORTS_QWEN3     = transformers_version >= Version("4.50.3")
+SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
+SUPPORTS_GEMMA = transformers_version >= Version("4.38")
+SUPPORTS_GEMMA2 = transformers_version >= Version("4.42")
+SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2")
+SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0")
+SUPPORTS_GRANITE = transformers_version >= Version("4.46.0")
+SUPPORTS_QWEN3 = transformers_version >= Version("4.50.3")
 SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3")
 SUPPORTS_FALCON_H1 = transformers_version >= Version("4.53.0")
-SUPPORTS_GEMMA3N   = transformers_version >= Version("4.53.0")
-SUPPORTS_GPTOSS    = transformers_version >= Version("4.55.0")
+SUPPORTS_GEMMA3N = transformers_version >= Version("4.53.0")
+SUPPORTS_GPTOSS = transformers_version >= Version("4.55.0")
 if SUPPORTS_GEMMA:
-    from .gemma  import FastGemmaModel
+    from .gemma import FastGemmaModel
 if SUPPORTS_GEMMA2:
     from .gemma2 import FastGemma2Model
-pass
 if SUPPORTS_FALCON_H1:
     from .falcon_h1 import FastFalconH1Model
-pass
 import torch
 from ._utils import (
     patch_compiling_bitsandbytes,
@@ -93,7 +91,7 @@
 global FORCE_FLOAT32
 # Forces float32 precision since float16 goes to infinity
 FORCE_FLOAT32 = [
-    "gemma3,", # Add comma bc gemma3 will match gemma3n
+    "gemma3,",  # Add comma bc gemma3 will match gemma3n
     "gemma3n",
     "gpt_oss",
 ]
@@ -103,107 +101,112 @@
 DISABLE_COMPILE_MODEL_NAMES = [
     "aya_vision",
     "modernbert",
-    "granite,llava_next", # Granite-vision 3
+    "granite,llava_next",  # Granite-vision 3
 ]
 
 global DISABLE_SDPA_MODEL_NAMES
 # Disables some SDPA modules since it's wrong
 DISABLE_SDPA_MODEL_NAMES = [
-    "gemma3,", # Add comma bc gemma3 will match gemma3n
+    "gemma3,",  # Add comma bc gemma3 will match gemma3n
 ]
 
 
 class FastLanguageModel(FastLlamaModel):
     @staticmethod
     def from_pretrained(
-        model_name                 = "unsloth/Llama-3.2-1B-Instruct",
-        max_seq_length             = 2048,
-        dtype                      = None,
-        load_in_4bit               = True,  # 4bit QLoRA
-        load_in_8bit               = False, # 8bit  LoRA
-        load_in_16bit              = False, # 16bit LoRA
-        full_finetuning            = False,
-        token                      = None,
-        device_map                 = "sequential",
-        rope_scaling               = None,
-        fix_tokenizer              = True,
-        trust_remote_code          = False,
+        model_name = "unsloth/Llama-3.2-1B-Instruct",
+        max_seq_length = 2048,
+        dtype = None,
+        load_in_4bit = True,  # 4bit QLoRA
+        load_in_8bit = False,  # 8bit  LoRA
+        load_in_16bit = False,  # 16bit LoRA
+        full_finetuning = False,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,
+        fix_tokenizer = True,
+        trust_remote_code = False,
         use_gradient_checkpointing = "unsloth",
-        resize_model_vocab         = None,
-        revision                   = None,
-        use_exact_model_name       = False,
-        offload_embedding          = False,
-
-        fast_inference             = False, # uses vLLM
-        gpu_memory_utilization     = 0.5,
-        float8_kv_cache            = False,
-        random_state               = 3407,
-        max_lora_rank              = 64,
-        disable_log_stats          = True,
-        qat_scheme                 = None,
-        *args, **kwargs,
+        resize_model_vocab = None,
+        revision = None,
+        use_exact_model_name = False,
+        offload_embedding = False,
+        fast_inference = False,  # uses vLLM
+        gpu_memory_utilization = 0.5,
+        float8_kv_cache = False,
+        random_state = 3407,
+        max_lora_rank = 64,
+        disable_log_stats = True,
+        qat_scheme = None,
+        *args,
+        **kwargs,
     ):
         # Login to allow private models
-        if token is None: token = get_token()
+        if token is None:
+            token = get_token()
         if token is not None:
             try:
                 from huggingface_hub import login
+
                 login(token = token)
             except:
                 pass
         if load_in_8bit or full_finetuning or qat_scheme is not None:
             return FastModel.from_pretrained(
-                model_name                 = model_name,
-                max_seq_length             = max_seq_length,
-                dtype                      = dtype,
-                load_in_4bit               = load_in_4bit,
-                load_in_8bit               = load_in_8bit,
-                load_in_16bit              = load_in_16bit,
-                full_finetuning            = full_finetuning,
-                token                      = token,
-                device_map                 = device_map,
-                rope_scaling               = rope_scaling, # [TODO] No effect
-                fix_tokenizer              = fix_tokenizer, # [TODO] No effect
-                trust_remote_code          = trust_remote_code,
+                model_name = model_name,
+                max_seq_length = max_seq_length,
+                dtype = dtype,
+                load_in_4bit = load_in_4bit,
+                load_in_8bit = load_in_8bit,
+                load_in_16bit = load_in_16bit,
+                full_finetuning = full_finetuning,
+                token = token,
+                device_map = device_map,
+                rope_scaling = rope_scaling,  # [TODO] No effect
+                fix_tokenizer = fix_tokenizer,  # [TODO] No effect
+                trust_remote_code = trust_remote_code,
                 use_gradient_checkpointing = use_gradient_checkpointing,
-                resize_model_vocab         = resize_model_vocab, # [TODO] No effect
-                revision                   = revision,
-                return_logits              = False, # Return logits
-                fullgraph                  = True, # No graph breaks
-                use_exact_model_name       = use_exact_model_name,
-                offload_embedding          = offload_embedding,
-
+                resize_model_vocab = resize_model_vocab,  # [TODO] No effect
+                revision = revision,
+                return_logits = False,  # Return logits
+                fullgraph = True,  # No graph breaks
+                use_exact_model_name = use_exact_model_name,
+                offload_embedding = offload_embedding,
                 # Pass vLLM/inference parameters
-                fast_inference             = fast_inference,
-                gpu_memory_utilization     = gpu_memory_utilization,
-                float8_kv_cache            = float8_kv_cache,
-                random_state               = random_state,
-                max_lora_rank              = max_lora_rank,
-                disable_log_stats          = disable_log_stats,
-
-                qat_scheme                 = qat_scheme,
-                *args, **kwargs,
+                fast_inference = fast_inference,
+                gpu_memory_utilization = gpu_memory_utilization,
+                float8_kv_cache = float8_kv_cache,
+                random_state = random_state,
+                max_lora_rank = max_lora_rank,
+                disable_log_stats = disable_log_stats,
+                qat_scheme = qat_scheme,
+                *args,
+                **kwargs,
             )
-        pass
 
-        if token is None: token = get_token()
+        if token is None:
+            token = get_token()
         if isinstance(dtype, str) and dtype in ["float16", "bfloat16"]:
             dtype = getattr(torch, dtype)
-        assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16
-                or dtype == torch.float32)
+        assert (
+            dtype is None
+            or dtype == torch.float16
+            or dtype == torch.bfloat16
+            or dtype == torch.float32
+        )
 
         if fast_inference:
             if importlib.util.find_spec("vllm") is None:
                 raise ImportError(
-                    "Unsloth: Please install vLLM before enabling `fast_inference`!\n"\
+                    "Unsloth: Please install vLLM before enabling `fast_inference`!\n"
                     "You can do this in a terminal via `pip install vllm`"
                 )
-            pass
-        pass
         # Check if 4bit is allowed specifically for AMD
         if not ALLOW_BITSANDBYTES and not use_exact_model_name:
             if load_in_4bit or load_in_8bit or model_name.lower().endswith("-bnb-4bit"):
-                print("Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.")
+                print(
+                    "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now."
+                )
             load_in_4bit = False
 
         old_model_name = model_name
@@ -211,7 +214,9 @@ def from_pretrained(
             model_name = get_model_name(model_name, load_in_4bit)
         # Check if pre-quantized models are allowed
         # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
-        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(("-unsloth-bnb-4bit", "-bnb-4bit")):
+        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
+            ("-unsloth-bnb-4bit", "-bnb-4bit")
+        ):
             model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
             model_name = model_name.lower().removesuffix("-bnb-4bit")
         # Change -BF16 to all False for 4bit, 8bit etc
@@ -222,11 +227,16 @@ def from_pretrained(
 
         if USE_MODELSCOPE and not os.path.exists(model_name):
             from modelscope import snapshot_download
+
             model_name = snapshot_download(model_name)
-        pass
 
         # First check if it's a normal model via AutoConfig
-        from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
+        from huggingface_hub.utils import (
+            disable_progress_bars,
+            enable_progress_bars,
+            are_progress_bars_disabled,
+        )
+
         was_disabled = are_progress_bars_disabled()
         disable_progress_bars()
 
@@ -246,7 +256,7 @@ def from_pretrained(
             autoconfig_error = str(error)
             if "architecture" in autoconfig_error:
                 raise ValueError(
-                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"\
+                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                     f"Please update transformers via `pip install --upgrade transformers` and try again."
                 )
             is_model = False
@@ -262,11 +272,10 @@ def from_pretrained(
             peft_error = str(error)
             if "architecture" in peft_error:
                 raise ValueError(
-                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"\
+                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                     f"Please update transformers via `pip install --upgrade transformers` and try again."
                 )
             is_peft = False
-        pass
 
         # Old transformers versions check
         both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
@@ -274,9 +283,9 @@ def from_pretrained(
         # Error out if both LoRA and normal model config exists.
         if both_exist:
             raise RuntimeError(
-                "Unsloth: Your repo has a LoRA adapter and a base model.\n"\
-                "You have 2 files `config.json` and `adapter_config.json`.\n"\
-                "We must only allow one config file.\n"\
+                "Unsloth: Your repo has a LoRA adapter and a base model.\n"
+                "You have 2 files `config.json` and `adapter_config.json`.\n"
+                "We must only allow one config file.\n"
                 "Please separate the LoRA and base models to 2 repos."
             )
         model_types = get_transformers_model_type(
@@ -292,28 +301,30 @@ def from_pretrained(
         if SUPPORTS_LLAMA32:
             # Check if folder exists locally
             if os.path.isdir(model_name):
-                exist_adapter_config = os.path.exists(os.path.join(model_name, "adapter_config.json"))
-                exist_config         = os.path.exists(os.path.join(model_name, "config.json"))
+                exist_adapter_config = os.path.exists(
+                    os.path.join(model_name, "adapter_config.json")
+                )
+                exist_config = os.path.exists(os.path.join(model_name, "config.json"))
                 both_exist = exist_adapter_config and exist_config
             else:
                 # Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows.
                 files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
                 files = list(os.path.split(x)[-1] for x in files)
-                if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2:
+                if (
+                    sum(x == "adapter_config.json" or x == "config.json" for x in files)
+                    >= 2
+                ):
                     both_exist = True
-                pass
-            pass
-        pass
 
         if not is_model and not is_peft:
             error = autoconfig_error if autoconfig_error is not None else peft_error
             # Old transformers version
             if "rope_scaling" in error.lower() and not SUPPORTS_LLAMA31:
                 raise ImportError(
-                    f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"\
-                    f"This includes Llama 3.1. The minimum required version is 4.43.2\n"\
-                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
-                    f"to obtain the latest transformers build, then restart this session."\
+                    f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"
+                    f"This includes Llama 3.1. The minimum required version is 4.43.2\n"
+                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
+                    f"to obtain the latest transformers build, then restart this session."
                 )
             # Create a combined error message showing both failures
             combined_error = (
@@ -322,7 +333,6 @@ def from_pretrained(
                 f"PeftConfig error: {peft_error}\n\n"
             )
             raise RuntimeError(combined_error)
-        pass
 
         # Get base model for PEFT:
         if is_peft:
@@ -332,7 +342,9 @@ def from_pretrained(
                 model_name = get_model_name(model_name, load_in_4bit)
             # Check if pre-quantized models are allowed
             # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
-            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(("-unsloth-bnb-4bit", "-bnb-4bit")):
+            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
+                ("-unsloth-bnb-4bit", "-bnb-4bit")
+            ):
                 model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
                 model_name = model_name.lower().removesuffix("-bnb-4bit")
             # Change -BF16 to all False for 4bit, 8bit etc
@@ -346,73 +358,77 @@ def from_pretrained(
                 token = token,
                 trust_remote_code = trust_remote_code,
             )
-        pass
 
-        if not was_disabled: enable_progress_bars()
+        if not was_disabled:
+            enable_progress_bars()
 
         if model_type == "llama":
             scaling_type = None
             if getattr(model_config, "rope_scaling", None) is not None:
                 scaling_type1 = model_config.rope_scaling.get("type", None)
                 scaling_type2 = model_config.rope_scaling.get("rope_type", None)
-                scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
-            pass
+                scaling_type = (
+                    scaling_type1 if scaling_type1 is not None else scaling_type2
+                )
 
             if scaling_type == "llama3" and not SUPPORTS_LLAMA31:
                 raise ImportError(
-                    f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"\
-                    f"The minimum required version is 4.43.2\n"\
-                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
-                    f"to obtain the latest transformers build, then restart this session."\
+                    f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"
+                    f"The minimum required version is 4.43.2\n"
+                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
+                    f"to obtain the latest transformers build, then restart this session."
                 )
 
             dispatch_model = FastLlamaModel
 
-        elif model_type == "mistral": dispatch_model = FastMistralModel
+        elif model_type == "mistral":
+            dispatch_model = FastMistralModel
         elif model_type == "gemma":
             if not SUPPORTS_GEMMA:
                 raise ImportError(
-                    f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
-                    f"The minimum required version is 4.38.\n"\
-                    f'Try `pip install --upgrade "transformers>=4.38"`\n'\
-                    f"to obtain the latest transformers build, then restart this session."\
+                    f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"
+                    f"The minimum required version is 4.38.\n"
+                    f'Try `pip install --upgrade "transformers>=4.38"`\n'
+                    f"to obtain the latest transformers build, then restart this session."
                 )
             dispatch_model = FastGemmaModel
         elif model_type == "gemma2":
             if not SUPPORTS_GEMMA2:
                 raise ImportError(
-                    f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
-                    f"The minimum required version is 4.42.3.\n"\
-                    f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
-                    f"to obtain the latest transformers build, then restart this session."\
+                    f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"
+                    f"The minimum required version is 4.42.3.\n"
+                    f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
+                    f"to obtain the latest transformers build, then restart this session."
                 )
             # Also check for softcapping support in flash-attn which is faster!
             if is_bfloat16_supported() and not HAS_FLASH_ATTENTION:
                 print(
-                    "Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"\
-                    "To install flash-attn, do the below:\n"\
+                    "Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"
+                    "To install flash-attn, do the below:\n"
                     '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
                 )
             elif HAS_FLASH_ATTENTION and not HAS_FLASH_ATTENTION_SOFTCAPPING:
                 print(
-                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
-                    "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
-                    "To update flash-attn, do the below:\n"\
+                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
+                    "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
+                    "To update flash-attn, do the below:\n"
                     '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
                 )
 
             dispatch_model = FastGemma2Model
         elif model_type == "qwen2":
             dispatch_model = FastQwen2Model
-        elif model_type == "qwen3":# or model_type == "qwen3_moe":
+        elif model_type == "qwen3":  # or model_type == "qwen3_moe":
             if not SUPPORTS_QWEN3 or not SUPPORTS_QWEN3_MOE:
                 raise ImportError(
-                    f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\n"\
-                    f"The minimum required version is 4.50.3.\n"\
-                    f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
-                    f"to obtain the latest transformers build, then restart this session."\
+                    f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\n"
+                    f"The minimum required version is 4.50.3.\n"
+                    f'Try `pip install --upgrade "transformers>=4.50.3"`\n'
+                    f"to obtain the latest transformers build, then restart this session."
                 )
-            dispatch_model = FastQwen3Model if model_type == "qwen3" else FastQwen3MoeModel
+            dispatch_model = (
+                FastQwen3Model if model_type == "qwen3" else FastQwen3MoeModel
+            )
         # elif model_type == "falcon_h1":
         #     dispatch_model = FastFalconH1Model
         #     if not SUPPORTS_FALCON_H1:
@@ -430,108 +446,109 @@ def from_pretrained(
         #     dispatch_model = FastGraniteModel
         else:
             return FastModel.from_pretrained(
-                model_name                 = old_model_name,
-                max_seq_length             = max_seq_length,
-                dtype                      = dtype,
-                load_in_4bit               = load_in_4bit,
-                load_in_8bit               = load_in_8bit,
-                load_in_16bit              = load_in_16bit,
-                full_finetuning            = full_finetuning,
-                token                      = token,
-                device_map                 = device_map,
-                rope_scaling               = rope_scaling, # [TODO] No effect
-                fix_tokenizer              = fix_tokenizer, # [TODO] No effect
-                trust_remote_code          = trust_remote_code,
+                model_name = old_model_name,
+                max_seq_length = max_seq_length,
+                dtype = dtype,
+                load_in_4bit = load_in_4bit,
+                load_in_8bit = load_in_8bit,
+                load_in_16bit = load_in_16bit,
+                full_finetuning = full_finetuning,
+                token = token,
+                device_map = device_map,
+                rope_scaling = rope_scaling,  # [TODO] No effect
+                fix_tokenizer = fix_tokenizer,  # [TODO] No effect
+                trust_remote_code = trust_remote_code,
                 use_gradient_checkpointing = use_gradient_checkpointing,
-                resize_model_vocab         = resize_model_vocab, # [TODO] No effect
-                revision                   = revision,
-                return_logits              = False, # Return logits
-                fullgraph                  = True, # No graph breaks
-                use_exact_model_name       = use_exact_model_name,
-                offload_embedding          = offload_embedding,
-
+                resize_model_vocab = resize_model_vocab,  # [TODO] No effect
+                revision = revision,
+                return_logits = False,  # Return logits
+                fullgraph = True,  # No graph breaks
+                use_exact_model_name = use_exact_model_name,
+                offload_embedding = offload_embedding,
                 # Pass vLLM/inference parameters
-                fast_inference             = fast_inference,
-                gpu_memory_utilization     = gpu_memory_utilization,
-                float8_kv_cache            = float8_kv_cache,
-                random_state               = random_state,
-                max_lora_rank              = max_lora_rank,
-                disable_log_stats          = disable_log_stats,
-
-                *args, **kwargs,
+                fast_inference = fast_inference,
+                gpu_memory_utilization = gpu_memory_utilization,
+                float8_kv_cache = float8_kv_cache,
+                random_state = random_state,
+                max_lora_rank = max_lora_rank,
+                disable_log_stats = disable_log_stats,
+                *args,
+                **kwargs,
             )
-        pass
 
         if use_gradient_checkpointing == "unsloth":
             patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
 
         # Check if this is local model since the tokenizer gets overwritten
-        if  os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
-            os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
-            os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")):
-
+        if (
+            os.path.exists(os.path.join(old_model_name, "tokenizer_config.json"))
+            and os.path.exists(os.path.join(old_model_name, "tokenizer.json"))
+            and os.path.exists(os.path.join(old_model_name, "special_tokens_map.json"))
+        ):
             tokenizer_name = old_model_name
         else:
             tokenizer_name = kwargs.pop("tokenizer_name", None)
-        pass
 
         if fast_inference:
             fast_inference, model_name = fast_inference_setup(model_name, model_config)
 
         model, tokenizer = dispatch_model.from_pretrained(
-            model_name        = model_name,
-            max_seq_length    = max_seq_length,
-            dtype             = _get_dtype(dtype),
-            load_in_4bit      = load_in_4bit,
-            token             = token,
-            device_map        = device_map,
-            rope_scaling      = rope_scaling,
-            fix_tokenizer     = fix_tokenizer,
-            model_patcher     = dispatch_model,
-            tokenizer_name    = tokenizer_name,
+            model_name = model_name,
+            max_seq_length = max_seq_length,
+            dtype = _get_dtype(dtype),
+            load_in_4bit = load_in_4bit,
+            token = token,
+            device_map = device_map,
+            rope_scaling = rope_scaling,
+            fix_tokenizer = fix_tokenizer,
+            model_patcher = dispatch_model,
+            tokenizer_name = tokenizer_name,
             trust_remote_code = trust_remote_code,
-            revision          = revision if not is_peft else None,
-
-            fast_inference    = fast_inference,
+            revision = revision if not is_peft else None,
+            fast_inference = fast_inference,
             gpu_memory_utilization = gpu_memory_utilization,
-            float8_kv_cache   = float8_kv_cache,
-            random_state      = random_state,
-            max_lora_rank     = max_lora_rank,
+            float8_kv_cache = float8_kv_cache,
+            random_state = random_state,
+            max_lora_rank = max_lora_rank,
             disable_log_stats = disable_log_stats,
-            *args, **kwargs,
+            *args,
+            **kwargs,
         )
 
         if resize_model_vocab is not None:
             model.resize_token_embeddings(resize_model_vocab)
-        pass
 
         # In case the model supports tagging, add the unsloth tag.
         if hasattr(model, "add_model_tags"):
-            model.add_model_tags(["unsloth",])
-        pass
+            model.add_model_tags(
+                [
+                    "unsloth",
+                ]
+            )
         if hasattr(tokenizer, "add_model_tags"):
-            tokenizer.add_model_tags(["unsloth",])
-        pass
+            tokenizer.add_model_tags(
+                [
+                    "unsloth",
+                ]
+            )
 
         if load_in_4bit:
             # Fix up bitsandbytes config
             compute_dtype = dtype_from_config(model.config)
-            quantization_config = \
-            {
+            quantization_config = {
                 # Sometimes compute_dtype is not a string!!
-                "bnb_4bit_compute_dtype"           : compute_dtype,
-                "bnb_4bit_quant_type"              : "nf4",
-                "bnb_4bit_use_double_quant"        : True,
-                "llm_int8_enable_fp32_cpu_offload" : False,
-                "llm_int8_has_fp16_weight"         : False,
-                "llm_int8_skip_modules"            : None,
-                "llm_int8_threshold"               : 6.0,
-                "load_in_4bit"                     : True,
-                "load_in_8bit"                     : False,
-                "quant_method"                     : "bitsandbytes",
+                "bnb_4bit_compute_dtype": compute_dtype,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_use_double_quant": True,
+                "llm_int8_enable_fp32_cpu_offload": False,
+                "llm_int8_has_fp16_weight": False,
+                "llm_int8_skip_modules": None,
+                "llm_int8_threshold": 6.0,
+                "load_in_4bit": True,
+                "load_in_8bit": False,
+                "quant_method": "bitsandbytes",
             }
-            model.config.update({"quantization_config" : quantization_config})
-        pass
+            model.config.update({"quantization_config": quantization_config})
 
         if is_peft:
             # From https://github.com/huggingface/peft/issues/184
@@ -546,10 +563,7 @@ def from_pretrained(
             )
             # Patch it as well!
             model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
-        pass
         return model, tokenizer
-    pass
-pass
 
 
 from ..kernels import (
@@ -560,88 +574,94 @@ def from_pretrained(
 from transformers import (
     AutoModelForCausalLM,
 )
+
 try:
     from transformers import AutoModelForImageTextToText
+
     AutoModelForVision2Seq = AutoModelForImageTextToText
 except:
     from transformers import AutoModelForVision2Seq
-pass
 
 
 class FastModel(FastBaseModel):
     @staticmethod
     def from_pretrained(
-        model_name                 = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
-        max_seq_length             = 2048,
-        dtype                      = None,
-        load_in_4bit               = True,  # 4bit QLoRA
-        load_in_8bit               = False, # 8bit  LoRA
-        load_in_16bit              = False, # 16bit LoRA
-        full_finetuning            = False,
-        token                      = None,
-        device_map                 = "sequential",
-        rope_scaling               = None, # [TODO] No effect
-        fix_tokenizer              = True, # [TODO] No effect
-        trust_remote_code          = False,
+        model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
+        max_seq_length = 2048,
+        dtype = None,
+        load_in_4bit = True,  # 4bit QLoRA
+        load_in_8bit = False,  # 8bit  LoRA
+        load_in_16bit = False,  # 16bit LoRA
+        full_finetuning = False,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,  # [TODO] No effect
+        fix_tokenizer = True,  # [TODO] No effect
+        trust_remote_code = False,
         use_gradient_checkpointing = "unsloth",
-        resize_model_vocab         = None, # [TODO] No effect
-        revision                   = None,
-        return_logits              = False, # Return logits
-        fullgraph                  = True, # No graph breaks
-        use_exact_model_name       = False,
-        auto_model                 = None,
-        whisper_language           = None,
-        whisper_task               = None,
-        unsloth_force_compile      = False,
-        offload_embedding          = False,
-
+        resize_model_vocab = None,  # [TODO] No effect
+        revision = None,
+        return_logits = False,  # Return logits
+        fullgraph = True,  # No graph breaks
+        use_exact_model_name = False,
+        auto_model = None,
+        whisper_language = None,
+        whisper_task = None,
+        unsloth_force_compile = False,
+        offload_embedding = False,
         # Add the missing vLLM/inference parameters
-        fast_inference             = False, # uses vLLM
-        gpu_memory_utilization     = 0.5,
-        float8_kv_cache            = False,
-        random_state               = 3407,
-        max_lora_rank              = 64,
-        disable_log_stats          = True,
-
-        qat_scheme                 = None,
-        *args, **kwargs,
+        fast_inference = False,  # uses vLLM
+        gpu_memory_utilization = 0.5,
+        float8_kv_cache = False,
+        random_state = 3407,
+        max_lora_rank = 64,
+        disable_log_stats = True,
+        qat_scheme = None,
+        *args,
+        **kwargs,
     ):
-        if token is None: token = get_token()
+        if token is None:
+            token = get_token()
         # Login to allow private models
         if token is not None:
             try:
                 from huggingface_hub import login
+
                 login(token = token)
             except:
                 pass
-        if whisper_language is not None: assert(type(whisper_language) is str)
-        if whisper_task is not None: assert(type(whisper_task) is str)
+        if whisper_language is not None:
+            assert type(whisper_language) is str
+        if whisper_task is not None:
+            assert type(whisper_task) is str
         SUPPORTS_BFLOAT16 = is_bfloat16_supported()
         if dtype is None:
             dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
         elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
-            logger.warning_once("Device does not support bfloat16. Will change to float16.")
+            logger.warning_once(
+                "Device does not support bfloat16. Will change to float16."
+            )
             dtype = torch.float16
-        assert(dtype in (torch.float16, torch.bfloat16, torch.float32))
+        assert dtype in (torch.float16, torch.bfloat16, torch.float32)
 
         patch_compiled_autograd()
         patch_compiling_bitsandbytes()
 
         if full_finetuning and (load_in_4bit or load_in_8bit):
-            print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
-            load_in_4bit  = False
-            load_in_8bit  = False
+            print(
+                "Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA."
+            )
+            load_in_4bit = False
+            load_in_8bit = False
             load_in_16bit = False
-        pass
 
         if int(load_in_4bit) + int(load_in_8bit) + int(load_in_16bit) >= 2:
             raise RuntimeError(
-                "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\n"\
-                "Also, we by default set `load_in_4bit = True`.\n"\
-                "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`\n"\
+                "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\n"
+                "Also, we by default set `load_in_4bit = True`.\n"
+                "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`\n"
                 "If you want 16bit LoRA finetuning, set `load_in_16bit = True`"
             )
-        pass
 
         if qat_scheme is not None and not full_finetuning:
             raise ValueError(
@@ -652,7 +672,9 @@ def from_pretrained(
         # Check if 4bit is allowed specifically for AMD
         if not ALLOW_BITSANDBYTES and not use_exact_model_name:
             if load_in_4bit or load_in_8bit or model_name.lower().endswith("-bnb-4bit"):
-                print("Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.")
+                print(
+                    "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now."
+                )
             load_in_4bit = False
 
         old_model_name = model_name
@@ -660,7 +682,9 @@ def from_pretrained(
             model_name = get_model_name(model_name, load_in_4bit)
         # Check if pre-quantized models are allowed
         # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
-        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(("-unsloth-bnb-4bit", "-bnb-4bit")):
+        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
+            ("-unsloth-bnb-4bit", "-bnb-4bit")
+        ):
             model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
             model_name = model_name.lower().removesuffix("-bnb-4bit")
         # Change -BF16 to all False for 4bit, 8bit etc
@@ -672,11 +696,16 @@ def from_pretrained(
         # Check modelscope
         if USE_MODELSCOPE and not os.path.exists(model_name):
             from modelscope import snapshot_download
+
             model_name = snapshot_download(model_name)
-        pass
 
         # First check if it's a normal model via AutoConfig
-        from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
+        from huggingface_hub.utils import (
+            disable_progress_bars,
+            enable_progress_bars,
+            are_progress_bars_disabled,
+        )
+
         was_disabled = are_progress_bars_disabled()
         disable_progress_bars()
 
@@ -696,7 +725,7 @@ def from_pretrained(
             autoconfig_error = str(error)
             if "architecture" in autoconfig_error:
                 raise ValueError(
-                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"\
+                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                     f"Please update transformers via `pip install --upgrade transformers` and try again."
                 )
             is_model = False
@@ -712,19 +741,18 @@ def from_pretrained(
             peft_error = str(error)
             if "architecture" in peft_error:
                 raise ValueError(
-                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"\
+                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                     f"Please update transformers via `pip install --upgrade transformers` and try again."
                 )
             is_peft = False
-        pass
         # Old transformers versions check
         both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
         # Error out if both LoRA and normal model config exists.
         if both_exist:
             raise RuntimeError(
-                "Unsloth: Your repo has a LoRA adapter and a base model.\n"\
-                "You have 2 files `config.json` and `adapter_config.json`.\n"\
-                "We must only allow one config file.\n"\
+                "Unsloth: Your repo has a LoRA adapter and a base model.\n"
+                "You have 2 files `config.json` and `adapter_config.json`.\n"
+                "We must only allow one config file.\n"
                 "Please separate the LoRA and base models to 2 repos."
             )
         model_types = get_transformers_model_type(
@@ -735,96 +763,121 @@ def from_pretrained(
         # Save model types and loading method
         lowered_model_name = model_name.lower()
         string = os.environ.get("UNSLOTH_MODEL_NAME", "") + model_types_all
-        if load_in_4bit:  string += "_load_in_4bit_"
-        if load_in_8bit:  string += "_load_in_8bit_"
-        if load_in_16bit: string += "_load_in_16bit_"
+        if load_in_4bit:
+            string += "_load_in_4bit_"
+        if load_in_8bit:
+            string += "_load_in_8bit_"
+        if load_in_16bit:
+            string += "_load_in_16bit_"
         os.environ["UNSLOTH_MODEL_NAME"] = string
 
         # Check versions
-        LATEST  = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`'
+        LATEST = "\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`"
         NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`'
         # Pixtral
         if "pixtral" in model_types_all and transformers_version < Version("4.49.0"):
-            raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST)
+            raise RuntimeError(
+                "Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST
+            )
         # Qwen 2.5
         elif "qwen2_5" in model_types_all and transformers_version < Version("4.49.0"):
-            raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST)
+            raise RuntimeError(
+                "Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST
+            )
         # Gemma 3N must be before Gemma 3
         elif "gemma3n" in model_types_all:
             if transformers_version < Version("4.53.0"):
-                raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST)
+                raise RuntimeError(
+                    "Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST
+                )
             os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
-            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \
-                "float16;torch.float16;torch.float16;"\
-                "if name.endswith('norm'): "\
-                "module._pre_set_compute_dtype = torch.float32\n"\
-                ";"\
+            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
+                "float16;torch.float16;torch.float16;"
+                "if name.endswith('norm'): "
+                "module._pre_set_compute_dtype = torch.float32\n"
+                ";"
                 "from unsloth_zoo.temporary_patches.gemma3n import patch_Gemma3nConv_Embed_forwards; patch_Gemma3nConv_Embed_forwards()"
+            )
             # Set norms to float32 since anyways they get upcasted to float32
             # common in both gemma-3 and gemma-3n
             os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
         # Gemma 3
         elif "gemma3" in model_types_all:
             if transformers_version < Version("4.50.0.dev0"):
-                raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY)
+                raise RuntimeError(
+                    "Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY
+                )
             # Set norms to float32 since anyways they get upcasted to float32
             # common in both gemma-3 and gemma-3n
             os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
         # Cohere
-        elif "cohere2" in model_types_all and transformers_version < Version("4.50.0.dev0"):
-            raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "cohere2" in model_types_all and transformers_version < Version(
+            "4.50.0.dev0"
+        ):
+            raise RuntimeError(
+                "Unsloth: Cohere's Command model only works on transformers >= 4.50.0."
+                + NIGHTLY
+            )
         # Sesame
         elif "csm" in model_types_all:
-            os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial" # Inference is too slow
-            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails
-            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \
-                "all;torch.float32;torch.float16;"\
-                "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"\
+            os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"  # Inference is too slow
+            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"  # Sesame fails
+            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
+                "all;torch.float32;torch.float16;"
+                "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"
                 ";"
+            )
         # Granite 4
-        elif 'granitemoehybrid' in model_types_all:
+        elif "granitemoehybrid" in model_types_all:
             # Granite-4 rms norms are stored as 16 bit, but we upcast
             os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
             os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
         # Olmo 2
-        elif "olmo2" in model_types_all and transformers_version < Version("4.50.0.dev0"):
-            raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY)
+        elif "olmo2" in model_types_all and transformers_version < Version(
+            "4.50.0.dev0"
+        ):
+            raise RuntimeError(
+                "Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY
+            )
         elif "falcon_h1" in model_types_all:
             # Falcon must use float32 Triton ie TRITON_F32_DEFAULT = 'ieee'
             # since Mamba kernels error out on using lower precision
-            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \
-                "float16;torch.float32;torch.float16;"\
-                "if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)"\
-                ";"\
+            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
+                "float16;torch.float32;torch.float16;"
+                "if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)"
+                ";"
                 "os.environ['TRITON_F32_DEFAULT'] = 'ieee'"
+            )
         elif "gpt_oss" in model_types_all:
             os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
             if not load_in_4bit:
                 # Only upcast MoE biases for MXFP4, not BnB
                 # Set norms to float32 since anyways they get upcasted to float32
-                os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \
-                    "all;None;None;"\
-                    "x = 'gate_up_proj_bias'\n"\
-                    "if hasattr(module, x): "\
-                    "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"\
-                    ""\
-                    "x = 'down_proj_bias'\n"\
-                    "if hasattr(module, x): "\
-                    "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"\
-                    ""\
+                os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
+                    "all;None;None;"
+                    "x = 'gate_up_proj_bias'\n"
+                    "if hasattr(module, x): "
+                    "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"
+                    ""
+                    "x = 'down_proj_bias'\n"
+                    "if hasattr(module, x): "
+                    "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"
+                    ""
                     ";"
+                )
             else:
                 # Set down projection compute dtype to be float32 for float16 machines
                 # Set norms to float32 since anyways they get upcasted to float32
-                os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \
-                    "torch.float16;torch.bfloat16;torch.float16;"\
-                    "if ('down_projs' in name) and hasattr(module, 'weight') and "\
-                    "torch.amax(dequantize_module_weight(module)) >= 0:"\
-                    "module._pre_set_compute_dtype = torch.float32\n"\
-                    ""\
-                    "if ('mlp.router' in name) and hasattr(module, 'weight'):"\
-                    "module._pre_set_compute_dtype = torch.float32\n"\
+                os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
+                    "torch.float16;torch.bfloat16;torch.float16;"
+                    "if ('down_projs' in name) and hasattr(module, 'weight') and "
+                    "torch.amax(dequantize_module_weight(module)) >= 0:"
+                    "module._pre_set_compute_dtype = torch.float32\n"
+                    ""
+                    "if ('mlp.router' in name) and hasattr(module, 'weight'):"
+                    "module._pre_set_compute_dtype = torch.float32\n"
                     ";"
+                )
             # Set norms to float32 since anyways they get upcasted to float32
             os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
         else:
@@ -833,40 +886,43 @@ def from_pretrained(
                     os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"
                     os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
                     if transformers_version < Version("4.50.0.dev0"):
-                        raise RuntimeError(f"Unsloth: {check_model_name} only works on transformers >= 4.50.0." + NIGHTLY)
+                        raise RuntimeError(
+                            f"Unsloth: {check_model_name} only works on transformers >= 4.50.0."
+                            + NIGHTLY
+                        )
                     break
-        pass
 
         if auto_model is not None:
             # All other models need to disable static cache
             os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
-        pass
 
         # New transformers need to check manually.
         if SUPPORTS_LLAMA32:
             # Check if folder exists locally
             if os.path.isdir(model_name):
-                exist_adapter_config = os.path.exists(os.path.join(model_name, "adapter_config.json"))
-                exist_config         = os.path.exists(os.path.join(model_name, "config.json"))
+                exist_adapter_config = os.path.exists(
+                    os.path.join(model_name, "adapter_config.json")
+                )
+                exist_config = os.path.exists(os.path.join(model_name, "config.json"))
                 both_exist = exist_adapter_config and exist_config
             else:
                 files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
                 files = list(os.path.split(x)[-1] for x in files)
-                if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2:
+                if (
+                    sum(x == "adapter_config.json" or x == "config.json" for x in files)
+                    >= 2
+                ):
                     both_exist = True
-                pass
-            pass
-        pass
 
         if not is_model and not is_peft:
             error = autoconfig_error if autoconfig_error is not None else peft_error
             # Old transformers version
             if "rope_scaling" in error.lower() and not SUPPORTS_LLAMA31:
                 raise ImportError(
-                    f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"\
-                    f"This includes Llama 3.1. The minimum required version is 4.43.2\n"\
-                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
-                    f"to obtain the latest transformers build, then restart this session."\
+                    f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"
+                    f"This includes Llama 3.1. The minimum required version is 4.43.2\n"
+                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
+                    f"to obtain the latest transformers build, then restart this session."
                 )
             # Create a combined error message showing both failures
             combined_error = (
@@ -875,7 +931,6 @@ def from_pretrained(
                 f"PeftConfig error: {peft_error}\n\n"
             )
             raise RuntimeError(combined_error)
-        pass
 
         # Get base model for PEFT:
         if is_peft:
@@ -885,7 +940,9 @@ def from_pretrained(
                 model_name = get_model_name(model_name, load_in_4bit)
             # Check if pre-quantized models are allowed
             # For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
-            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(("-unsloth-bnb-4bit", "-bnb-4bit")):
+            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
+                ("-unsloth-bnb-4bit", "-bnb-4bit")
+            ):
                 model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
                 model_name = model_name.lower().removesuffix("-bnb-4bit")
             # Change -BF16 to all False for 4bit, 8bit etc
@@ -899,9 +956,9 @@ def from_pretrained(
                 token = token,
                 trust_remote_code = trust_remote_code,
             )
-        pass
 
-        if not was_disabled: enable_progress_bars()
+        if not was_disabled:
+            enable_progress_bars()
 
         do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
         if do_logging:
@@ -914,143 +971,146 @@ def from_pretrained(
         os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
         do_forced_float32 = False
         for model_type_arch in model_types:
-            if model_type_arch != "siglip": break
+            if model_type_arch != "siglip":
+                break
         global FORCE_FLOAT32
         for disable_name in FORCE_FLOAT32:
             # add comma to model_types_all matching in case of exact match for end
-            if (disable_name.lower() == model_type_arch.lower().replace("-", "").replace("_", "") or \
-                disable_name.lower() in model_types_all) and \
-                ((dtype == torch.float16) or not SUPPORTS_BFLOAT16):
+            if (
+                disable_name.lower()
+                == model_type_arch.lower().replace("-", "").replace("_", "")
+                or disable_name.lower() in model_types_all
+            ) and ((dtype == torch.float16) or not SUPPORTS_BFLOAT16):
                 os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
-                dtype = torch.bfloat16 # Change to bfloat16 loading
+                dtype = torch.bfloat16  # Change to bfloat16 loading
                 break
-        pass
         # Patch gradient checkpointing
         if use_gradient_checkpointing == "unsloth":
             patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
         with redirector:
             patch_loss_functions(torch_compile = False)
             model_types, supports_sdpa = unsloth_compile_transformers(
-                dtype                   = dtype,
-                model_name              = model_name,
-                model_types             = model_types,
-                token                   = token,
-                sdpa_dynamic_mask       = True,
-                sdpa_bool_masks         = True,
-                sdpa_gqa_replace        = True,
-                sdpa_dynamic_compile    = True,
-                compile_attention       = True,
-                disable_causal_masks    = True,
-                compile_torch_modules   = True,
-                compile_custom_modules  = True,
-                compile_function_calls  = True,
-                fuse_lm_head            = True,
-                gradient_checkpointing  = True,
-                manual_replacements     = True,
-                fast_lora_forwards      = True,
-                fast_residual_stream    = False,
-                accurate_accumulation   = True,
-                epilogue_fusion         = True,
-                max_autotune            = False,
-                shape_padding           = True,
-                cudagraphs              = False,
-                debug                   = False,
-                fullgraph               = fullgraph,
-                import_from_cache       = False,
-                disable                 = False,
-                return_logits           = return_logits,
-                trust_remote_code       = trust_remote_code,
-                unsloth_force_compile   = unsloth_force_compile,
+                dtype = dtype,
+                model_name = model_name,
+                model_types = model_types,
+                token = token,
+                sdpa_dynamic_mask = True,
+                sdpa_bool_masks = True,
+                sdpa_gqa_replace = True,
+                sdpa_dynamic_compile = True,
+                compile_attention = True,
+                disable_causal_masks = True,
+                compile_torch_modules = True,
+                compile_custom_modules = True,
+                compile_function_calls = True,
+                fuse_lm_head = True,
+                gradient_checkpointing = True,
+                manual_replacements = True,
+                fast_lora_forwards = True,
+                fast_residual_stream = False,
+                accurate_accumulation = True,
+                epilogue_fusion = True,
+                max_autotune = False,
+                shape_padding = True,
+                cudagraphs = False,
+                debug = False,
+                fullgraph = fullgraph,
+                import_from_cache = False,
+                disable = False,
+                return_logits = return_logits,
+                trust_remote_code = trust_remote_code,
+                unsloth_force_compile = unsloth_force_compile,
             )
-        pass
         # Fix SDPA issues
         for model_type in DISABLE_SDPA_MODEL_NAMES:
             if model_type in model_types_all:
                 supports_sdpa = False
-        pass
 
         # Check if this is local model since the tokenizer gets overwritten
-        if  os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
-            os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
-            os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")):
-
+        if (
+            os.path.exists(os.path.join(old_model_name, "tokenizer_config.json"))
+            and os.path.exists(os.path.join(old_model_name, "tokenizer.json"))
+            and os.path.exists(os.path.join(old_model_name, "special_tokens_map.json"))
+        ):
             tokenizer_name = old_model_name
         else:
             tokenizer_name = kwargs.pop("tokenizer_name", None)
-        pass
 
         # Check if VLM
         architectures = getattr(model_config, "architectures", None)
-        if architectures is None: architectures = []
+        if architectures is None:
+            architectures = []
         is_vlm = any(x.endswith("ForConditionalGeneration") for x in architectures)
         is_vlm = is_vlm or hasattr(model_config, "vision_config")
         if auto_model is None:
             auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
 
         model, tokenizer = FastBaseModel.from_pretrained(
-            model_name        = model_name,
-            max_seq_length    = max_seq_length,
-            dtype             = _get_dtype(dtype),
-            load_in_4bit      = load_in_4bit,
-            load_in_8bit      = load_in_8bit,
-            load_in_16bit     = load_in_16bit,
-            full_finetuning   = full_finetuning,
-            token             = token,
-            device_map        = device_map,
+            model_name = model_name,
+            max_seq_length = max_seq_length,
+            dtype = _get_dtype(dtype),
+            load_in_4bit = load_in_4bit,
+            load_in_8bit = load_in_8bit,
+            load_in_16bit = load_in_16bit,
+            full_finetuning = full_finetuning,
+            token = token,
+            device_map = device_map,
             trust_remote_code = trust_remote_code,
-            revision          = revision if not is_peft else None,
-            model_types       = model_types,
-            tokenizer_name    = tokenizer_name,
-            auto_model        = auto_model,
+            revision = revision if not is_peft else None,
+            model_types = model_types,
+            tokenizer_name = tokenizer_name,
+            auto_model = auto_model,
             use_gradient_checkpointing = use_gradient_checkpointing,
-            supports_sdpa     = supports_sdpa,
-            whisper_language  = whisper_language,
-            whisper_task      = whisper_task,
-            auto_config       = model_config,
+            supports_sdpa = supports_sdpa,
+            whisper_language = whisper_language,
+            whisper_task = whisper_task,
+            auto_config = model_config,
             offload_embedding = offload_embedding,
-
             # Pass vLLM/inference parameters
-            fast_inference         = fast_inference,
+            fast_inference = fast_inference,
             gpu_memory_utilization = gpu_memory_utilization,
-            float8_kv_cache        = float8_kv_cache,
-            random_state           = random_state,
-            max_lora_rank          = max_lora_rank,
-            disable_log_stats      = disable_log_stats,
-
-            *args, **kwargs,
+            float8_kv_cache = float8_kv_cache,
+            random_state = random_state,
+            max_lora_rank = max_lora_rank,
+            disable_log_stats = disable_log_stats,
+            *args,
+            **kwargs,
         )
 
         if resize_model_vocab is not None:
             model.resize_token_embeddings(resize_model_vocab)
-        pass
 
         # In case the model supports tagging, add the unsloth tag.
         if hasattr(model, "add_model_tags"):
-            model.add_model_tags(["unsloth",])
-        pass
+            model.add_model_tags(
+                [
+                    "unsloth",
+                ]
+            )
         if hasattr(tokenizer, "add_model_tags"):
-            tokenizer.add_model_tags(["unsloth",])
-        pass
+            tokenizer.add_model_tags(
+                [
+                    "unsloth",
+                ]
+            )
 
         if load_in_4bit:
             # Fix up bitsandbytes config
             compute_dtype = dtype_from_config(model.config)
-            quantization_config = \
-            {
+            quantization_config = {
                 # Sometimes compute_dtype is not a string!!
-                "bnb_4bit_compute_dtype"           : compute_dtype,
-                "bnb_4bit_quant_type"              : "nf4",
-                "bnb_4bit_use_double_quant"        : True,
-                "llm_int8_enable_fp32_cpu_offload" : False,
-                "llm_int8_has_fp16_weight"         : False,
-                "llm_int8_skip_modules"            : None,
-                "llm_int8_threshold"               : 6.0,
-                "load_in_4bit"                     : True,
-                "load_in_8bit"                     : False,
-                "quant_method"                     : "bitsandbytes",
+                "bnb_4bit_compute_dtype": compute_dtype,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_use_double_quant": True,
+                "llm_int8_enable_fp32_cpu_offload": False,
+                "llm_int8_has_fp16_weight": False,
+                "llm_int8_skip_modules": None,
+                "llm_int8_threshold": 6.0,
+                "load_in_4bit": True,
+                "load_in_8bit": False,
+                "quant_method": "bitsandbytes",
             }
-            model.config.update({"quantization_config" : quantization_config})
-        pass
+            model.config.update({"quantization_config": quantization_config})
 
         if is_peft:
             # From https://github.com/huggingface/peft/issues/184
@@ -1064,21 +1124,21 @@ def from_pretrained(
                 trust_remote_code = trust_remote_code,
             )
             # Patch it as well!
-            model = FastBaseModel.post_patch_model(model, use_gradient_checkpointing, trust_remote_code  = trust_remote_code)
-        pass
+            model = FastBaseModel.post_patch_model(
+                model, use_gradient_checkpointing, trust_remote_code = trust_remote_code
+            )
 
         # Apply QAT if specified
         if qat_scheme is not None:
             print("Unsloth: Applying QAT to mitigate quantization degradation")
             model = _prepare_model_for_qat(model, qat_scheme)
-        pass
 
         return model, tokenizer
-    pass
-pass
+
 
 class FastVisionModel(FastModel):
     pass
 
+
 class FastTextModel(FastModel):
     pass
diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py
index e2e24eae3..631f614cb 100644
--- a/unsloth/models/loader_utils.py
+++ b/unsloth/models/loader_utils.py
@@ -13,45 +13,45 @@
 # limitations under the License.
 
 from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
+
 # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
 from packaging.version import Version
 from transformers import __version__ as transformers_version
+
 transformers_version = Version(transformers_version)
 SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
 
-BAD_MAPPINGS = \
-{
-    "unsloth/Qwen3-32B-unsloth-bnb-4bit".lower()          : "unsloth/Qwen3-32B-bnb-4bit".lower(), # 32B dynamic quant is way too big
-    "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit".lower()      : "unsloth/Qwen3-30B-A3B".lower(),      # HF loads MoEs too slowly
-    "unsloth/Qwen3-30B-A3B-bnb-4bit".lower()              : "unsloth/Qwen3-30B-A3B".lower(),      # We rather do it on the fly
-    "unsloth/Qwen3-30B-A3B-Base-unsloth-bnb-4bit".lower() : "unsloth/Qwen3-30B-A3B-Base".lower(), # HF loads MoEs too slowly
-    "unsloth/Qwen3-30B-A3B-Base-bnb-4bit".lower()         : "unsloth/Qwen3-30B-A3B-Base".lower(), # We rather do it on the fly
+BAD_MAPPINGS = {
+    "unsloth/Qwen3-32B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-32B-bnb-4bit".lower(),  # 32B dynamic quant is way too big
+    "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B".lower(),  # HF loads MoEs too slowly
+    "unsloth/Qwen3-30B-A3B-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B".lower(),  # We rather do it on the fly
+    "unsloth/Qwen3-30B-A3B-Base-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B-Base".lower(),  # HF loads MoEs too slowly
+    "unsloth/Qwen3-30B-A3B-Base-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B-Base".lower(),  # We rather do it on the fly
 }
 
+
 def __get_model_name(
     model_name,
     load_in_4bit = True,
-    INT_TO_FLOAT_MAPPER  = None,
-    FLOAT_TO_INT_MAPPER  = None,
+    INT_TO_FLOAT_MAPPER = None,
+    FLOAT_TO_INT_MAPPER = None,
     MAP_TO_UNSLOTH_16bit = None,
 ):
     model_name = str(model_name)
     lower_model_name = model_name.lower()
 
     if not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER:
-
         model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
         print(
-            f"Unsloth: Your transformers version of {transformers_version} does not support native "\
-            f"4bit loading.\nThe minimum required version is 4.37.\n"\
-            f'Try `pip install --upgrade "transformers>=4.37"`\n'\
-            f"to obtain the latest transformers build, then restart this session.\n"\
+            f"Unsloth: Your transformers version of {transformers_version} does not support native "
+            f"4bit loading.\nThe minimum required version is 4.37.\n"
+            f'Try `pip install --upgrade "transformers>=4.37"`\n'
+            f"to obtain the latest transformers build, then restart this session.\n"
             f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
         )
         return model_name
-    
-    elif not load_in_4bit and lower_model_name in INT_TO_FLOAT_MAPPER:
 
+    elif not load_in_4bit and lower_model_name in INT_TO_FLOAT_MAPPER:
         new_model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
         # logger.warning_once(
         #     f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
@@ -60,80 +60,87 @@ def __get_model_name(
         return new_model_name
 
     elif not load_in_4bit and lower_model_name in MAP_TO_UNSLOTH_16bit:
-
         new_model_name = MAP_TO_UNSLOTH_16bit[lower_model_name]
         return new_model_name
 
     elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER:
-
         # Support returning original full -bnb-4bit name if specified specifically
         # since we'll map it to the dynamic version instead
         if lower_model_name.endswith("-bnb-4bit"):
             return lower_model_name
-        
+
         new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name]
         # logger.warning_once(
         #     f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
         #     f"We shall load `{new_model_name}` for 4x faster loading."
         # )
         return new_model_name
-    pass
 
     return None
-pass
 
 
 def _get_new_mapper():
     try:
         import requests
+
         new_mapper = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py"
-        with requests.get(new_mapper, timeout = 3) as new_mapper: new_mapper = new_mapper.text
-        new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER"):]
-        new_mapper = new_mapper\
-            .replace("INT_TO_FLOAT_MAPPER",  "NEW_INT_TO_FLOAT_MAPPER")\
-            .replace("FLOAT_TO_INT_MAPPER",  "NEW_FLOAT_TO_INT_MAPPER")\
+        with requests.get(new_mapper, timeout = 3) as new_mapper:
+            new_mapper = new_mapper.text
+        new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER") :]
+        new_mapper = (
+            new_mapper.replace("INT_TO_FLOAT_MAPPER", "NEW_INT_TO_FLOAT_MAPPER")
+            .replace("FLOAT_TO_INT_MAPPER", "NEW_FLOAT_TO_INT_MAPPER")
             .replace("MAP_TO_UNSLOTH_16bit", "NEW_MAP_TO_UNSLOTH_16bit")
+        )
 
         exec(new_mapper, globals())
-        return NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit
+        return (
+            NEW_INT_TO_FLOAT_MAPPER,
+            NEW_FLOAT_TO_INT_MAPPER,
+            NEW_MAP_TO_UNSLOTH_16bit,
+        )
     except:
         return {}, {}, {}
-    pass
-pass
 
 
 def get_model_name(model_name, load_in_4bit = True):
     new_model_name = __get_model_name(
         model_name = model_name,
         load_in_4bit = load_in_4bit,
-        INT_TO_FLOAT_MAPPER  = INT_TO_FLOAT_MAPPER,
-        FLOAT_TO_INT_MAPPER  = FLOAT_TO_INT_MAPPER,
+        INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
+        FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
         MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
     )
     # In the rare case, we convert bad model names to other names
     # For eg too large dynamic quants or MoEs
-    if new_model_name is not None and type(new_model_name) is str and \
-        new_model_name.lower() in BAD_MAPPINGS:
+    if (
+        new_model_name is not None
+        and type(new_model_name) is str
+        and new_model_name.lower() in BAD_MAPPINGS
+    ):
         new_model_name = BAD_MAPPINGS[new_model_name.lower()]
 
-    if new_model_name is None and model_name.count("/") == 1 and model_name[0].isalnum():
+    if (
+        new_model_name is None
+        and model_name.count("/") == 1
+        and model_name[0].isalnum()
+    ):
         # Try checking if a new Unsloth version allows it!
-        NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = _get_new_mapper()
+        NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = (
+            _get_new_mapper()
+        )
         upgraded_model_name = __get_model_name(
             model_name = model_name,
             load_in_4bit = load_in_4bit,
-            INT_TO_FLOAT_MAPPER  = NEW_INT_TO_FLOAT_MAPPER,
-            FLOAT_TO_INT_MAPPER  = NEW_FLOAT_TO_INT_MAPPER,
+            INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
+            FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
             MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
         )
         if upgraded_model_name is not None:
             raise NotImplementedError(
-                f"Unsloth: {model_name} is not supported in your current Unsloth version! Please update Unsloth via:\n\n"\
-                'pip uninstall unsloth unsloth_zoo -y\n'\
-                'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'\
-                'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'\
+                f"Unsloth: {model_name} is not supported in your current Unsloth version! Please update Unsloth via:\n\n"
+                "pip uninstall unsloth unsloth_zoo -y\n"
+                'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
+                'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
             )
-        pass
-    pass
     return new_model_name if new_model_name is not None else model_name
-pass
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index e56554041..20ac67613 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -17,1045 +17,1024 @@
     "FLOAT_TO_INT_MAPPER",
 ]
 
-__INT_TO_FLOAT_MAPPER = \
-{
-    "unsloth/mistral-7b-bnb-4bit" : (
+__INT_TO_FLOAT_MAPPER = {
+    "unsloth/mistral-7b-bnb-4bit": (
         "unsloth/mistral-7b",
         "mistralai/Mistral-7B-v0.1",
     ),
-    "unsloth/llama-2-7b-bnb-4bit" : (
+    "unsloth/llama-2-7b-bnb-4bit": (
         "unsloth/llama-2-7b",
         "meta-llama/Llama-2-7b-hf",
     ),
-    "unsloth/llama-2-13b-bnb-4bit" : (
+    "unsloth/llama-2-13b-bnb-4bit": (
         "unsloth/llama-2-13b",
         "meta-llama/Llama-2-13b-hf",
     ),
-    "unsloth/codellama-34b-bnb-4bit" : (
-        "codellama/CodeLlama-34b-hf",
-    ),
-    "unsloth/zephyr-sft-bnb-4bit" : (
+    "unsloth/codellama-34b-bnb-4bit": ("codellama/CodeLlama-34b-hf",),
+    "unsloth/zephyr-sft-bnb-4bit": (
         "unsloth/zephyr-sft",
         "HuggingFaceH4/mistral-7b-sft-beta",
     ),
-    "unsloth/tinyllama-bnb-4bit" : (
+    "unsloth/tinyllama-bnb-4bit": (
         "unsloth/tinyllama",
         "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
     ),
-    "unsloth/tinyllama-chat-bnb-4bit" : (
+    "unsloth/tinyllama-chat-bnb-4bit": (
         "unsloth/tinyllama-chat",
         "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     ),
-    "unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : (
+    "unsloth/mistral-7b-instruct-v0.1-bnb-4bit": (
         "unsloth/mistral-7b-instruct-v0.1",
         "mistralai/Mistral-7B-Instruct-v0.1",
     ),
-    "unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : (
+    "unsloth/mistral-7b-instruct-v0.2-bnb-4bit": (
         "unsloth/mistral-7b-instruct-v0.2",
         "mistralai/Mistral-7B-Instruct-v0.2",
     ),
-    "unsloth/llama-2-7b-chat-bnb-4bit" : (
+    "unsloth/llama-2-7b-chat-bnb-4bit": (
         "unsloth/llama-2-7b-chat",
         "meta-llama/Llama-2-7b-chat-hf",
     ),
-    "unsloth/llama-2-7b-chat-bnb-4bit" : (
+    "unsloth/llama-2-7b-chat-bnb-4bit": (
         "unsloth/llama-2-7b-chat",
         "meta-llama/Llama-2-7b-chat-hf",
     ),
-    "unsloth/Mixtral-8x7B-v0.1-unsloth-bnb-4bit" : (
+    "unsloth/Mixtral-8x7B-v0.1-unsloth-bnb-4bit": (
         "unsloth/Mixtral-8x7B-v0.1",
         "mistralai/Mixtral-8x7B-v0.1",
         "unsloth/Mixtral-8x7B-v0.1-bnb-4bit",
     ),
-    "unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit" : (
+    "unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit": (
         "unsloth/Mixtral-8x7B-Instruct-v0.1",
         "mistralai/Mixtral-8x7B-Instruct-v0.1",
         "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit",
     ),
-    "unsloth/codellama-7b-bnb-4bit" : (
+    "unsloth/codellama-7b-bnb-4bit": (
         "unsloth/codellama-7b",
         "codellama/CodeLlama-7b-hf",
     ),
-    "unsloth/codellama-13b-bnb-4bit" : (
-        "codellama/CodeLlama-13b-hf",
-    ),
-    "unsloth/yi-6b-bnb-4bit" : (
+    "unsloth/codellama-13b-bnb-4bit": ("codellama/CodeLlama-13b-hf",),
+    "unsloth/yi-6b-bnb-4bit": (
         "unsloth/yi-6b",
         "01-ai/Yi-6B",
     ),
-    "unsloth/solar-10.7b-bnb-4bit" : (
-        "upstage/SOLAR-10.7B-v1.0",
-    ),
-    "unsloth/gemma-7b-bnb-4bit" : (
+    "unsloth/solar-10.7b-bnb-4bit": ("upstage/SOLAR-10.7B-v1.0",),
+    "unsloth/gemma-7b-bnb-4bit": (
         "unsloth/gemma-7b",
         "google/gemma-7b",
     ),
-    "unsloth/gemma-2b-bnb-4bit" : (
+    "unsloth/gemma-2b-bnb-4bit": (
         "unsloth/gemma-2b",
         "google/gemma-2b",
     ),
-    "unsloth/gemma-7b-it-bnb-4bit" : (
+    "unsloth/gemma-7b-it-bnb-4bit": (
         "unsloth/gemma-7b-it",
         "google/gemma-7b-it",
     ),
-    "unsloth/gemma-2b-bnb-4bit" : (
+    "unsloth/gemma-2b-bnb-4bit": (
         "unsloth/gemma-2b-it",
         "google/gemma-2b-it",
     ),
-    "unsloth/mistral-7b-v0.2-bnb-4bit" : (
+    "unsloth/mistral-7b-v0.2-bnb-4bit": (
         "unsloth/mistral-7b-v0.2",
         "alpindale/Mistral-7B-v0.2-hf",
     ),
-    "unsloth/gemma-1.1-2b-it-bnb-4bit" : (
+    "unsloth/gemma-1.1-2b-it-bnb-4bit": (
         "unsloth/gemma-1.1-2b-it",
         "google/gemma-1.1-2b-it",
     ),
-    "unsloth/gemma-1.1-7b-it-bnb-4bit" : (
+    "unsloth/gemma-1.1-7b-it-bnb-4bit": (
         "unsloth/gemma-1.1-7b-it",
         "google/gemma-1.1-7b-it",
     ),
-    "unsloth/Starling-LM-7B-beta" : (
+    "unsloth/Starling-LM-7B-beta": (
         "unsloth/Starling-LM-7B-beta",
         "Nexusflow/Starling-LM-7B-beta",
     ),
-    "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit" : (
+    "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit": (
         "unsloth/Hermes-2-Pro-Mistral-7B",
         "NousResearch/Hermes-2-Pro-Mistral-7B",
     ),
-    "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit" : (
+    "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit": (
         "unsloth/OpenHermes-2.5-Mistral-7B",
         "teknium/OpenHermes-2.5-Mistral-7B",
     ),
-    "unsloth/codegemma-2b-bnb-4bit" : (
+    "unsloth/codegemma-2b-bnb-4bit": (
         "unsloth/codegemma-2b",
         "google/codegemma-2b",
     ),
-    "unsloth/codegemma-7b-bnb-4bit" : (
+    "unsloth/codegemma-7b-bnb-4bit": (
         "unsloth/codegemma-7b",
         "google/codegemma-7b",
     ),
-    "unsloth/codegemma-7b-it-bnb-4bit" : (
+    "unsloth/codegemma-7b-it-bnb-4bit": (
         "unsloth/codegemma-7b-it",
         "google/codegemma-7b-it",
     ),
-    "unsloth/llama-3-8b-bnb-4bit" : (
+    "unsloth/llama-3-8b-bnb-4bit": (
         "unsloth/llama-3-8b",
         "meta-llama/Meta-Llama-3-8B",
     ),
-    "unsloth/llama-3-8b-Instruct-bnb-4bit" : (
+    "unsloth/llama-3-8b-Instruct-bnb-4bit": (
         "unsloth/llama-3-8b-Instruct",
         "meta-llama/Meta-Llama-3-8B-Instruct",
     ),
-    "unsloth/llama-3-70b-bnb-4bit" : (
-        "meta-llama/Meta-Llama-3-70B",
-    ),
-    "unsloth/llama-3-70b-Instruct-bnb-4bit" : (
-        "meta-llama/Meta-Llama-3-70B-Instruct",
-    ),
-    "unsloth/Phi-3-mini-4k-instruct-bnb-4bit" : (
+    "unsloth/llama-3-70b-bnb-4bit": ("meta-llama/Meta-Llama-3-70B",),
+    "unsloth/llama-3-70b-Instruct-bnb-4bit": ("meta-llama/Meta-Llama-3-70B-Instruct",),
+    "unsloth/Phi-3-mini-4k-instruct-bnb-4bit": (
         "unsloth/Phi-3-mini-4k-instruct",
         "microsoft/Phi-3-mini-4k-instruct",
     ),
-    "unsloth/mistral-7b-v0.3-bnb-4bit" : (
+    "unsloth/mistral-7b-v0.3-bnb-4bit": (
         "unsloth/mistral-7b-v0.3",
         "mistralai/Mistral-7B-v0.3",
     ),
-    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit" : (
+    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit": (
         "unsloth/mistral-7b-instruct-v0.3",
         "mistralai/Mistral-7B-Instruct-v0.3",
     ),
-    "unsloth/Phi-3-medium-4k-instruct-bnb-4bit" : (
+    "unsloth/Phi-3-medium-4k-instruct-bnb-4bit": (
         "unsloth/Phi-3-medium-4k-instruct",
         "microsoft/Phi-3-medium-4k-instruct",
     ),
-    "unsloth/Qwen2-0.5B-bnb-4bit" : (
+    "unsloth/Qwen2-0.5B-bnb-4bit": (
         "unsloth/Qwen2-0.5B",
         "Qwen/Qwen2-0.5B",
     ),
-    "unsloth/Qwen2-0.5B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2-0.5B-Instruct-bnb-4bit": (
         "unsloth/Qwen2-0.5B-Instruct",
         "Qwen/Qwen2-0.5B-Instruct",
     ),
-    "unsloth/Qwen2-1.5B-bnb-4bit" : (
+    "unsloth/Qwen2-1.5B-bnb-4bit": (
         "unsloth/Qwen2-1.5B",
         "Qwen/Qwen2-1.5B",
     ),
-    "unsloth/Qwen2-1.5B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2-1.5B-Instruct-bnb-4bit": (
         "unsloth/Qwen2-1.5B-Instruct",
         "Qwen/Qwen2-1.5B-Instruct",
     ),
-    "unsloth/Qwen2-7B-bnb-4bit" : (
+    "unsloth/Qwen2-7B-bnb-4bit": (
         "unsloth/Qwen2-7B",
         "Qwen/Qwen2-7B",
     ),
-    "unsloth/Qwen2-7B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2-7B-Instruct-bnb-4bit": (
         "unsloth/Qwen2-7B-Instruct",
         "Qwen/Qwen2-7B-Instruct",
     ),
-    "unsloth/Qwen2-70B-bnb-4bit" : (
-        "Qwen/Qwen2-70B",
-    ),
-    "unsloth/Qwen2-70B-Instruct-bnb-4bit" : (
-        "Qwen/Qwen2-70B-Instruct",
-    ),
-    "mistralai/Codestral-22B-v0.1" : (
-        "mistral-community/Codestral-22B-v0.1",
-    ),
-    "unsloth/gemma-2-9b-bnb-4bit" : (
+    "unsloth/Qwen2-70B-bnb-4bit": ("Qwen/Qwen2-70B",),
+    "unsloth/Qwen2-70B-Instruct-bnb-4bit": ("Qwen/Qwen2-70B-Instruct",),
+    "mistralai/Codestral-22B-v0.1": ("mistral-community/Codestral-22B-v0.1",),
+    "unsloth/gemma-2-9b-bnb-4bit": (
         "unsloth/gemma-2-9b",
         "google/gemma-2-9b",
     ),
-    "unsloth/gemma-2-27b-bnb-4bit" : (
+    "unsloth/gemma-2-27b-bnb-4bit": (
         "unsloth/gemma-2-27b",
         "google/gemma-2-27b",
     ),
-    "unsloth/gemma-2-9b-it-bnb-4bit" : (
+    "unsloth/gemma-2-9b-it-bnb-4bit": (
         "unsloth/gemma-2-9b-it",
         "google/gemma-2-9b-it",
     ),
-    "unsloth/gemma-2-27b-it-bnb-4bit" : (
+    "unsloth/gemma-2-27b-it-bnb-4bit": (
         "unsloth/gemma-2-27b-it",
         "google/gemma-2-27b-it",
     ),
-    "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit" : ( # Old Phi pre July
+    "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit": (  # Old Phi pre July
         "unsloth/Phi-3-mini-4k-instruct-v0",
     ),
-    "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit" : ( # New 12b Mistral models
+    "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit": (  # New 12b Mistral models
         "unsloth/Mistral-Nemo-Instruct-2407",
         "mistralai/Mistral-Nemo-Instruct-2407",
     ),
-    "unsloth/Mistral-Nemo-Base-2407-bnb-4bit" : ( # New 12b Mistral models
+    "unsloth/Mistral-Nemo-Base-2407-bnb-4bit": (  # New 12b Mistral models
         "unsloth/Mistral-Nemo-Base-2407",
         "mistralai/Mistral-Nemo-Base-2407",
     ),
-    "unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit" : (
+    "unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit": (
         "unsloth/Meta-Llama-3.1-8B",
         "meta-llama/Meta-Llama-3.1-8B",
         "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
     ),
-    "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Meta-Llama-3.1-8B-Instruct",
         "meta-llama/Meta-Llama-3.1-8B-Instruct",
         "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
     ),
-    "unsloth/Llama-3.1-8B-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.1-8B-unsloth-bnb-4bit": (
         "unsloth/Llama-3.1-8B",
         "meta-llama/Llama-3.1-8B",
         "unsloth/Llama-3.1-8B-bnb-4bit",
     ),
-    "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Llama-3.1-8B-Instruct",
         "meta-llama/Llama-3.1-8B-Instruct",
         "unsloth/Llama-3.1-8B-Instruct-bnb-4bit",
     ),
-    "unsloth/Meta-Llama-3.1-70B-bnb-4bit" : (
+    "unsloth/Meta-Llama-3.1-70B-bnb-4bit": (
         "unsloth/Meta-Llama-3.1-70B",
         "meta-llama/Meta-Llama-3.1-70B",
     ),
-    "unsloth/Meta-Llama-3.1-405B-bnb-4bit" : (
-        "meta-llama/Meta-Llama-3.1-405B",
-    ),
-    "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" : (
+    "unsloth/Meta-Llama-3.1-405B-bnb-4bit": ("meta-llama/Meta-Llama-3.1-405B",),
+    "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit": (
         "meta-llama/Meta-Llama-3.1-405B-Instruct",
     ),
-    "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : (
+    "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit": (
         "unsloth/Meta-Llama-3.1-70B-Instruct",
         "meta-llama/Meta-Llama-3.1-70B-Instruct",
     ),
-    "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : (
+    "unsloth/Mistral-Large-Instruct-2407-bnb-4bit": (
         "mistralai/Mistral-Large-Instruct-2407",
     ),
-    "unsloth/gemma-2-2b-bnb-4bit" : (
+    "unsloth/gemma-2-2b-bnb-4bit": (
         "unsloth/gemma-2-2b",
         "google/gemma-2-2b",
     ),
-    "unsloth/gemma-2-2b-it-bnb-4bit" : (
+    "unsloth/gemma-2-2b-it-bnb-4bit": (
         "unsloth/gemma-2-2b-it",
         "google/gemma-2-2b-it",
     ),
-    "unsloth/Phi-3.5-mini-instruct-bnb-4bit" : (
+    "unsloth/Phi-3.5-mini-instruct-bnb-4bit": (
         "unsloth/Phi-3.5-mini-instruct",
         "microsoft/Phi-3.5-mini-instruct",
     ),
-    "unsloth/c4ai-command-r-08-2024-bnb-4bit" : (
-        "CohereForAI/c4ai-command-r-08-2024",
-    ),
-    "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : (
+    "unsloth/c4ai-command-r-08-2024-bnb-4bit": ("CohereForAI/c4ai-command-r-08-2024",),
+    "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit": (
         "CohereForAI/c4ai-command-r-plus-08-2024",
     ),
-    "unsloth/Llama-3.1-Storm-8B-bnb-4bit" : (
+    "unsloth/Llama-3.1-Storm-8B-bnb-4bit": (
         "unsloth/Llama-3.1-Storm-8B",
         "akjindal53244/Llama-3.1-Storm-8B",
     ),
-    "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : (
+    "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit": (
         "unsloth/Hermes-3-Llama-3.1-8B",
         "NousResearch/Hermes-3-Llama-3.1-8B",
     ),
-    "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : (
+    "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit": (
         "unsloth/Hermes-3-Llama-3.1-70B",
         "NousResearch/Hermes-3-Llama-3.1-70B",
     ),
-    "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : (
+    "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit": (
         "NousResearch/Hermes-3-Llama-3.1-405B",
     ),
-    "unsloth/SmolLM-135M-bnb-4bit" : (
+    "unsloth/SmolLM-135M-bnb-4bit": (
         "unsloth/SmolLM-135M",
         "HuggingFaceTB/SmolLM-135M",
     ),
-    "unsloth/SmolLM-360M-bnb-4bit" : (
+    "unsloth/SmolLM-360M-bnb-4bit": (
         "unsloth/SmolLM-360M",
         "HuggingFaceTB/SmolLM-360M",
     ),
-    "unsloth/SmolLM-1.7B-bnb-4bit" : (
+    "unsloth/SmolLM-1.7B-bnb-4bit": (
         "unsloth/SmolLM-1.7B",
         "HuggingFaceTB/SmolLM-1.7B",
     ),
-    "unsloth/SmolLM-135M-Instruct-bnb-4bit" : (
+    "unsloth/SmolLM-135M-Instruct-bnb-4bit": (
         "unsloth/SmolLM-135M-Instruct",
         "HuggingFaceTB/SmolLM-135M-Instruct",
     ),
-    "unsloth/SmolLM-360M-Instruct-bnb-4bit" : (
+    "unsloth/SmolLM-360M-Instruct-bnb-4bit": (
         "unsloth/SmolLM-360M-Instruct",
         "HuggingFaceTB/SmolLM-360M-Instruct",
     ),
-    "unsloth/SmolLM-1.7B-Instruct-bnb-4bit" : (
+    "unsloth/SmolLM-1.7B-Instruct-bnb-4bit": (
         "unsloth/SmolLM-1.7B-Instruct",
         "HuggingFaceTB/SmolLM-1.7B-Instruct",
     ),
-    "unsloth/Mistral-Small-Instruct-2409-bnb-4bit" : (
+    "unsloth/Mistral-Small-Instruct-2409-bnb-4bit": (
         "unsloth/Mistral-Small-Instruct-2409",
         "mistralai/Mistral-Small-Instruct-2409",
     ),
-    "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-0.5B-Instruct",
         "Qwen/Qwen2.5-0.5B-Instruct",
         "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-1.5B-Instruct",
         "Qwen/Qwen2.5-1.5B-Instruct",
         "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-3B-Instruct",
         "Qwen/Qwen2.5-3B-Instruct",
         "unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-7B-Instruct",
         "Qwen/Qwen2.5-7B-Instruct",
         "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-14B-Instruct",
         "Qwen/Qwen2.5-14B-Instruct",
         "unsloth/Qwen2.5-14B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-32B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-32B-Instruct",
         "Qwen/Qwen2.5-32B-Instruct",
     ),
-    "unsloth/Qwen2.5-72B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-72B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-72B-Instruct",
         "Qwen/Qwen2.5-72B-Instruct",
     ),
-    "unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-0.5B",
         "Qwen/Qwen2.5-0.5B",
         "unsloth/Qwen2.5-0.5B-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-1.5B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-1.5B-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-1.5B",
         "Qwen/Qwen2.5-1.5B",
         "unsloth/Qwen2.5-1.5B-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-3B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-3B-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-3B",
         "Qwen/Qwen2.5-3B",
         "unsloth/Qwen2.5-3B-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-7B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-7B-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-7B",
         "Qwen/Qwen2.5-7B",
         "unsloth/Qwen2.5-7B-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-14B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-14B-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-14B",
         "Qwen/Qwen2.5-14B",
         "unsloth/Qwen2.5-14B-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-32B-bnb-4bit" : (
+    "unsloth/Qwen2.5-32B-bnb-4bit": (
         "unsloth/Qwen2.5-32B",
         "Qwen/Qwen2.5-32B",
     ),
-    "unsloth/Qwen2.5-72B-bnb-4bit" : (
+    "unsloth/Qwen2.5-72B-bnb-4bit": (
         "unsloth/Qwen2.5-72B",
         "Qwen/Qwen2.5-72B",
     ),
-    "unsloth/Qwen2.5-Math-1.5B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Math-1.5B-bnb-4bit": (
         "unsloth/Qwen2.5-Math-1.5B",
         "Qwen/Qwen2.5-Math-1.5B",
     ),
-    "unsloth/Qwen2.5-Math-7B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Math-7B-bnb-4bit": (
         "unsloth/Qwen2.5-Math-7B",
         "Qwen/Qwen2.5-Math-7B",
     ),
-    "unsloth/Qwen2.5-Math-72B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Math-72B-bnb-4bit": (
         "unsloth/Qwen2.5-Math-72B",
         "Qwen/Qwen2.5-Math-72B",
     ),
-    "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Math-1.5B-Instruct",
         "Qwen/Qwen2.5-Math-1.5B-Instruct",
     ),
-    "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Math-7B-Instruct",
         "Qwen/Qwen2.5-Math-7B-Instruct",
     ),
-    "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Math-72B-Instruct",
         "Qwen/Qwen2.5-Math-72B-Instruct",
     ),
-    "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-0.5B",
         "Qwen/Qwen2.5-Coder-0.5B",
     ),
-    "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-1.5B",
         "Qwen/Qwen2.5-Coder-1.5B",
     ),
-    "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-3B-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-3B",
         "Qwen/Qwen2.5-Coder-3B",
     ),
-    "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-7B-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-7B",
         "Qwen/Qwen2.5-Coder-7B",
     ),
-    "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-14B-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-14B",
         "Qwen/Qwen2.5-Coder-14B",
     ),
-    "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-32B-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-32B",
         "Qwen/Qwen2.5-Coder-32B",
     ),
-    "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-0.5B-Instruct",
         "Qwen/Qwen2.5-Coder-0.5B-Instruct",
     ),
-    "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-1.5B-Instruct",
         "Qwen/Qwen2.5-Coder-1.5B-Instruct",
     ),
-    "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-3B-Instruct",
         "Qwen/Qwen2.5-Coder-3B-Instruct",
     ),
-    "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-7B-Instruct",
         "Qwen/Qwen2.5-Coder-7B-Instruct",
     ),
-    "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-14B-Instruct",
         "Qwen/Qwen2.5-Coder-14B-Instruct",
     ),
-    "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit": (
         "unsloth/Qwen2.5-Coder-32B-Instruct",
         "Qwen/Qwen2.5-Coder-32B-Instruct",
     ),
-    "unsloth/Llama-3.2-1B-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.2-1B-unsloth-bnb-4bit": (
         "unsloth/Llama-3.2-1B",
         "meta-llama/Llama-3.2-1B",
         "unsloth/Llama-3.2-1B-bnb-4bit",
     ),
-    "unsloth/Llama-3.2-3B-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.2-3B-unsloth-bnb-4bit": (
         "unsloth/Llama-3.2-3B",
         "meta-llama/Llama-3.2-3B",
         "unsloth/Llama-3.2-3B-bnb-4bit",
     ),
-    "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Llama-3.2-1B-Instruct",
         "meta-llama/Llama-3.2-1B-Instruct",
         "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
     ),
-    "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Llama-3.2-3B-Instruct",
         "meta-llama/Llama-3.2-3B-Instruct",
         "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
     ),
-    "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : (
+    "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit": (
         "unsloth/Llama-3.1-Nemotron-70B-Instruct",
         "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
     ),
-    "unsloth/Qwen2-VL-2B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2-VL-2B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2-VL-2B-Instruct",
         "Qwen/Qwen2-VL-2B-Instruct",
         "unsloth/Qwen2-VL-2B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2-VL-7B-Instruct",
         "Qwen/Qwen2-VL-7B-Instruct",
         "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2-VL-72B-Instruct-bnb-4bit" : (
+    "unsloth/Qwen2-VL-72B-Instruct-bnb-4bit": (
         "unsloth/Qwen2-VL-72B-Instruct",
         "Qwen/Qwen2-VL-72B-Instruct",
     ),
-    "unsloth/Qwen2-VL-2B-bnb-4bit" : (
+    "unsloth/Qwen2-VL-2B-bnb-4bit": (
         "unsloth/Qwen2-VL-2B",
         "Qwen/Qwen2-VL-2B",
     ),
-    "unsloth/Qwen2-VL-7B-bnb-4bit" : (
+    "unsloth/Qwen2-VL-7B-bnb-4bit": (
         "unsloth/Qwen2-VL-7B",
         "Qwen/Qwen2-VL-7B",
     ),
-    "unsloth/Qwen2-VL-72B-bnb-4bit" : (
+    "unsloth/Qwen2-VL-72B-bnb-4bit": (
         "unsloth/Qwen2-VL-72B",
         "Qwen/Qwen2-VL-72B",
     ),
-    "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit": (
         "unsloth/Llama-3.2-11B-Vision-Instruct",
         "meta-llama/Llama-3.2-11B-Vision-Instruct",
         "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
     ),
-    "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit" : (
+    "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit": (
         "unsloth/Llama-3.2-90B-Vision-Instruct",
         "meta-llama/Llama-3.2-90B-Vision-Instruct",
     ),
-    "unsloth/Llama-3.2-11B-Vision-unsloth-bnb-4bit" : (
+    "unsloth/Llama-3.2-11B-Vision-unsloth-bnb-4bit": (
         "unsloth/Llama-3.2-11B-Vision",
         "meta-llama/Llama-3.2-11B-Vision",
         "unsloth/Llama-3.2-11B-Vision-bnb-4bit",
     ),
-    "unsloth/Llama-3.2-90B-Vision-bnb-4bit" : (
+    "unsloth/Llama-3.2-90B-Vision-bnb-4bit": (
         "unsloth/Llama-3.2-90B-Vision",
         "meta-llama/Llama-3.2-90B-Vision",
     ),
-    "unsloth/Pixtral-12B-2409-unsloth-bnb-4bit" : (
+    "unsloth/Pixtral-12B-2409-unsloth-bnb-4bit": (
         "unsloth/Pixtral-12B-2409",
         "mistralai/Pixtral-12B-2409",
         "unsloth/Pixtral-12B-2409-bnb-4bit",
     ),
-    "unsloth/Pixtral-12B-2409-Base-bnb-4bit" : (
+    "unsloth/Pixtral-12B-2409-Base-bnb-4bit": (
         "unsloth/Pixtral-12B-Base-2409",
         "mistralai/Pixtral-12B-Base-2409",
     ),
-    "unsloth/llava-1.5-7b-hf-bnb-4bit" : (
+    "unsloth/llava-1.5-7b-hf-bnb-4bit": (
         "unsloth/llava-1.5-7b-hf",
         "llava-hf/llava-1.5-7b-hf",
     ),
-    "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit" : (
+    "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit": (
         "unsloth/llava-v1.6-mistral-7b-hf",
         "llava-hf/llava-v1.6-mistral-7b-hf",
     ),
-    "unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit" : (
+    "unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit": (
         "unsloth/Llama-3.1-Tulu-3-8B",
         "allenai/Llama-3.1-Tulu-3-8B",
     ),
-    "unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit" : (
+    "unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit": (
         "unsloth/Llama-3.1-Tulu-3-70B",
         "allenai/Llama-3.1-Tulu-3-70B",
     ),
-    "unsloth/QwQ-32B-Preview-bnb-4bit" : (
+    "unsloth/QwQ-32B-Preview-bnb-4bit": (
         "unsloth/QwQ-32B-Preview",
         "Qwen/QwQ-32B-Preview",
     ),
-    "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" : (
+    "unsloth/Llama-3.3-70B-Instruct-bnb-4bit": (
         "unsloth/Llama-3.3-70B-Instruct",
         "meta-llama/Llama-3.3-70B-Instruct",
     ),
-    "unsloth/phi-4-unsloth-bnb-4bit" : (
+    "unsloth/phi-4-unsloth-bnb-4bit": (
         "unsloth/phi-4",
         "microsoft/phi-4",
         "unsloth/phi-4-bnb-4bit",
     ),
-    "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit" : (
+    "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit": (
         "unsloth/DeepSeek-R1-Distill-Qwen-32B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
     ),
-    "unsloth/DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit" : (
+    "unsloth/DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit": (
         "unsloth/DeepSeek-R1-Distill-Qwen-14B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
         "unsloth/DeepSeek-R1-Distill-Qwen-14B-bnb-4bit",
     ),
-    "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit" : (
+    "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit": (
         "unsloth/DeepSeek-R1-Distill-Qwen-7B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
         "unsloth/DeepSeek-R1-Distill-Qwen-7B-bnb-4bit",
     ),
-    "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-unsloth-bnb-4bit" : (
+    "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-unsloth-bnb-4bit": (
         "unsloth/DeepSeek-R1-Distill-Qwen-1.5B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
         "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-bnb-4bit",
     ),
-    "unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit" : (
+    "unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit": (
         "unsloth/DeepSeek-R1-Distill-Llama-8B",
         "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "unsloth/DeepSeek-R1-Distill-Llama-8B-bnb-4bit",
     ),
-    "unsloth/DeepSeek-R1-Distill-Llama-70B-bnb-4bit" : (
+    "unsloth/DeepSeek-R1-Distill-Llama-70B-bnb-4bit": (
         "unsloth/DeepSeek-R1-Distill-Llama-70B",
         "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
     ),
-    "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit" : (
+    "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit": (
         "unsloth/Mistral-Small-24B-Base-2501",
         "mistralai/Mistral-Small-24B-Base-2501",
         "unsloth/Mistral-Small-24B-Base-2501-bnb-4bit",
     ),
-    "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit" : (
+    "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit": (
         "unsloth/Mistral-Small-24B-Instruct-2501",
         "mistralai/Mistral-Small-24B-Instruct-2501",
         "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-VL-3B-Instruct",
         "Qwen/Qwen2.5-VL-3B-Instruct",
         "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-VL-7B-Instruct",
         "Qwen/Qwen2.5-VL-7B-Instruct",
         "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-VL-32B-Instruct",
         "Qwen/Qwen2.5-VL-32B-Instruct",
         "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen2.5-VL-72B-Instruct",
         "Qwen/Qwen2.5-VL-72B-Instruct",
         "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit",
     ),
-    "unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit" : (
+    "unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit": (
         "unsloth/DeepHermes-3-Llama-3-8B-Preview",
         "agentica-org/DeepScaleR-1.5B-Preview",
         "unsloth/DeepScaleR-1.5B-Preview-bnb-4bit",
     ),
-    "unsloth/OpenThinker-7B-unsloth-bnb-4bit" : (
+    "unsloth/OpenThinker-7B-unsloth-bnb-4bit": (
         "unsloth/OpenThinker-7B",
         "open-thoughts/OpenThinker-7B",
         "unsloth/OpenThinker-7B-bnb-4bit",
     ),
-    "unsloth/granite-3.2-2b-instruct-unsloth-bnb-4bit" : (
+    "unsloth/granite-3.2-2b-instruct-unsloth-bnb-4bit": (
         "unsloth/granite-3.2-2b-instruct",
         "ibm-granite/granite-3.2-2b-instruct",
         "unsloth/granite-3.2-2b-instruct-bnb-4bit",
     ),
-    "unsloth/granite-3.2-8b-instruct-unsloth-bnb-4bit" : (
+    "unsloth/granite-3.2-8b-instruct-unsloth-bnb-4bit": (
         "unsloth/granite-3.2-8b-instruct",
         "ibm-granite/granite-3.2-8b-instruct",
         "unsloth/granite-3.2-8b-instruct-bnb-4bit",
     ),
-    "unsloth/QwQ-32B-unsloth-bnb-4bit" : (
+    "unsloth/QwQ-32B-unsloth-bnb-4bit": (
         "unsloth/QwQ-32B",
         "Qwen/QwQ-32B",
         "unsloth/QwQ-32B-bnb-4bit",
     ),
-    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit": (
         "unsloth/gemma-3-1b-it",
         "google/gemma-3-1b-it",
         "unsloth/gemma-3-1b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit": (
         "unsloth/gemma-3-4b-it",
         "google/gemma-3-4b-it",
         "unsloth/gemma-3-4b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit": (
         "unsloth/gemma-3-12b-it",
         "google/gemma-3-12b-it",
         "unsloth/gemma-3-12b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit": (
         "unsloth/gemma-3-27b-it",
         "google/gemma-3-27b-it",
         "unsloth/gemma-3-27b-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-1b-pt-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-1b-pt-unsloth-bnb-4bit": (
         "unsloth/gemma-3-1b-pt",
         "google/gemma-3-1b-pt",
         "unsloth/gemma-3-1b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-4b-pt-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-4b-pt-unsloth-bnb-4bit": (
         "unsloth/gemma-3-4b-pt",
         "google/gemma-3-4b-pt",
         "unsloth/gemma-3-4b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-12b-pt-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-12b-pt-unsloth-bnb-4bit": (
         "unsloth/gemma-3-12b-pt",
         "google/gemma-3-12b-pt",
         "unsloth/gemma-3-12b-pt-bnb-4bit",
     ),
-    "unsloth/gemma-3-27b-pt-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-27b-pt-unsloth-bnb-4bit": (
         "unsloth/gemma-3-27b-pt",
         "google/gemma-3-27b-pt",
         "unsloth/gemma-3-27b-pt-bnb-4bit",
     ),
-    "unsloth/reka-flash-3-unsloth-bnb-4bit" : (
+    "unsloth/reka-flash-3-unsloth-bnb-4bit": (
         "unsloth/reka-flash-3",
         "RekaAI/reka-flash-3",
         "unsloth/reka-flash-3-bnb-4bit",
     ),
-    "unsloth/c4ai-command-a-03-2025-unsloth-bnb-4bit" : (
+    "unsloth/c4ai-command-a-03-2025-unsloth-bnb-4bit": (
         "unsloth/c4ai-command-a-03-2025",
         "CohereForAI/c4ai-command-a-03-2025",
         "unsloth/c4ai-command-a-03-2025-bnb-4bit",
     ),
-    "unsloth/aya-vision-32b-unsloth-bnb-4bit" : (
+    "unsloth/aya-vision-32b-unsloth-bnb-4bit": (
         "unsloth/aya-vision-32b",
         "CohereForAI/aya-vision-32b",
         "unsloth/aya-vision-32b-bnb-4bit",
     ),
-    "unsloth/aya-vision-8b-unsloth-bnb-4bit" : (
+    "unsloth/aya-vision-8b-unsloth-bnb-4bit": (
         "unsloth/aya-vision-8b",
         "CohereForAI/aya-vision-8b",
         "unsloth/aya-vision-8b-bnb-4bit",
     ),
-    "unsloth/granite-vision-3.2-2b-unsloth-bnb-4bit" : (
+    "unsloth/granite-vision-3.2-2b-unsloth-bnb-4bit": (
         "unsloth/granite-vision-3.2-2b",
         "ibm-granite/granite-vision-3.2-2b",
         "unsloth/granite-vision-3.2-2b-bnb-4bit",
     ),
-    "unsloth/OLMo-2-0325-32B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/OLMo-2-0325-32B-Instruct-unsloth-bnb-4bit": (
         "unsloth/OLMo-2-0325-32B-Instruct",
         "allenai/OLMo-2-0325-32B-Instruct",
         "unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit",
     ),
-    "unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit" : (
+    "unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit": (
         "unsloth/Mistral-Small-3.1-24B-Instruct-2503",
         "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
         "unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit",
     ),
-    "unsloth/Mistral-Small-3.1-24B-Base-2503-unsloth-bnb-4bit" : (
+    "unsloth/Mistral-Small-3.1-24B-Base-2503-unsloth-bnb-4bit": (
         "unsloth/Mistral-Small-3.1-24B-Base-2503",
         "mistralai/Mistral-Small-3.1-24B-Base-2503",
         "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit",
     ),
-    "unsloth/Qwen3-0.6B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-0.6B-unsloth-bnb-4bit": (
         "unsloth/Qwen3-0.6B",
         "Qwen/Qwen3-0.6B",
         "unsloth/Qwen3-0.6B-bnb-4bit",
     ),
-    "unsloth/Qwen3-1.7B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-1.7B-unsloth-bnb-4bit": (
         "unsloth/Qwen3-1.7B",
         "Qwen/Qwen3-1.7B",
         "unsloth/Qwen3-1.7B-bnb-4bit",
     ),
-    "unsloth/Qwen3-4B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-4B-unsloth-bnb-4bit": (
         "unsloth/Qwen3-4B",
         "Qwen/Qwen3-4B",
         "unsloth/Qwen3-4B-bnb-4bit",
     ),
-    "unsloth/Qwen3-8B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-8B-unsloth-bnb-4bit": (
         "unsloth/Qwen3-8B",
         "Qwen/Qwen3-8B",
         "unsloth/Qwen3-8B-bnb-4bit",
     ),
-    "unsloth/Qwen3-14B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-14B-unsloth-bnb-4bit": (
         "unsloth/Qwen3-14B",
         "Qwen/Qwen3-14B",
         "unsloth/Qwen3-14B-bnb-4bit",
     ),
-    "unsloth/Qwen3-32B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-32B-unsloth-bnb-4bit": (
         "unsloth/Qwen3-32B",
         "Qwen/Qwen3-32B",
         "unsloth/Qwen3-32B-bnb-4bit",
     ),
-    "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit": (
         "unsloth/Qwen3-30B-A3B",
         "Qwen/Qwen3-30B-A3B",
         "unsloth/Qwen3-30B-A3B-bnb-4bit",
     ),
-    "unsloth/Qwen3-0.6B-Base-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-0.6B-Base-unsloth-bnb-4bit": (
         "unsloth/Qwen3-0.6B-Base",
         "Qwen/Qwen3-0.6B-Base",
         "unsloth/Qwen3-0.6B-Base-bnb-4bit",
     ),
-    "unsloth/Qwen3-1.7B-Base-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-1.7B-Base-unsloth-bnb-4bit": (
         "unsloth/Qwen3-1.7B-Base",
         "Qwen/Qwen3-1.7B-Base",
         "unsloth/Qwen3-1.7B-Base-bnb-4bit",
     ),
-    "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit": (
         "unsloth/Qwen3-4B-Base",
         "Qwen/Qwen3-4B-Base",
         "unsloth/Qwen3-4B-Base-bnb-4bit",
     ),
-    "unsloth/Qwen3-8B-Base-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-8B-Base-unsloth-bnb-4bit": (
         "unsloth/Qwen3-8B-Base",
         "Qwen/Qwen3-8B-Base",
         "unsloth/Qwen3-8B-Base-bnb-4bit",
     ),
-    "unsloth/Qwen3-14B-Base-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-14B-Base-unsloth-bnb-4bit": (
         "unsloth/Qwen3-14B-Base",
         "Qwen/Qwen3-14B-Base",
         "unsloth/Qwen3-14B-Base-bnb-4bit",
     ),
-    "unsloth/Qwen3-30B-A3B-Base-bnb-4bit" : (
+    "unsloth/Qwen3-30B-A3B-Base-bnb-4bit": (
         "unsloth/Qwen3-30B-A3B-Base",
         "Qwen/Qwen3-30B-A3B-Base",
     ),
-    "unsloth/phi-4-reasoning-unsloth-bnb-4bit" : (
+    "unsloth/phi-4-reasoning-unsloth-bnb-4bit": (
         "unsloth/phi-4-reasoning",
         "microsoft/Phi-4-reasoning",
         "unsloth/phi-4-reasoning-bnb-4bit",
     ),
-    "unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit" : (
+    "unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit": (
         "unsloth/phi-4-reasoning-plus",
         "microsoft/Phi-4-reasoning-plus",
         "unsloth/phi-4-reasoning-plus-bnb-4bit",
     ),
-    "unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit" : (
+    "unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit": (
         "unsloth/phi-4-mini-reasoning",
         "microsoft/Phi-4-mini-reasoning",
         "unsloth/phi-4-mini-reasoning-bnb-4bit",
     ),
-    "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit" : (
+    "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit": (
         "unsloth/Phi-4-mini-instruct",
         "microsoft/Phi-4-mini-instruct",
         "unsloth/Phi-4-mini-instruct-bnb-4bit",
     ),
-    "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : (
+    "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit": (
         "unsloth/orpheus-3b-0.1-pretrained",
         "canopylabs/orpheus-3b-0.1-pretrained",
         "unsloth/orpheus-3b-0.1-pretrained-bnb-4bit",
     ),
-    "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" : (
+    "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit": (
         "unsloth/orpheus-3b-0.1-ft",
         "canopylabs/orpheus-3b-0.1-ft",
         "unsloth/orpheus-3b-0.1-ft-bnb-4bit",
     ),
-    "unsloth/csm-1b" : (
+    "unsloth/csm-1b": (
         "unsloth/csm-1b",
         "sesame/csm-1b",
     ),
-    "unsloth/whisper-large-v3" : (
+    "unsloth/whisper-large-v3": (
         "unsloth/whisper-large-v3",
         "openai/whisper-large-v3",
     ),
-    "unsloth/whisper-large-v3-turbo" : (
+    "unsloth/whisper-large-v3-turbo": (
         "unsloth/whisper-large-v3-turbo",
         "openai/whisper-large-v3-turbo",
     ),
-    "unsloth/whisper-small" : (
+    "unsloth/whisper-small": (
         "unsloth/whisper-small",
         "openai/whisper-small",
     ),
-    "unsloth/CrisperWhisper" : (
+    "unsloth/CrisperWhisper": (
         "unsloth/CrisperWhisper",
         "nyrahealth/CrisperWhisper",
     ),
-    "unsloth/Llasa-1B" : (
+    "unsloth/Llasa-1B": (
         "unsloth/Llasa-1B",
         "HKUSTAudio/Llasa-1B",
     ),
-    "unsloth/Spark-TTS-0.5B" : (
+    "unsloth/Spark-TTS-0.5B": (
         "unsloth/Spark-TTS-0.5B",
         "SparkAudio/Spark-TTS-0.5B",
     ),
-    "unsloth/Llama-OuteTTS-1.0-1B" : (
+    "unsloth/Llama-OuteTTS-1.0-1B": (
         "unsloth/Llama-OuteTTS-1.0-1B",
         "OuteAI/Llama-OuteTTS-1.0-1B",
     ),
-    "unsloth/medgemma-4b-it-unsloth-bnb-4bit" : (
+    "unsloth/medgemma-4b-it-unsloth-bnb-4bit": (
         "unsloth/medgemma-4b-it",
         "google/medgemma-4b-it",
         "unsloth/medgemma-4b-it-bnb-4bit",
     ),
-    "unsloth/medgemma-27b-text-it-unsloth-bnb-4bit" : (
+    "unsloth/medgemma-27b-text-it-unsloth-bnb-4bit": (
         "unsloth/medgemma-27b-text-it",
         "google/medgemma-27b-text-it",
         "unsloth/medgemma-27b-text-it-bnb-4bit",
     ),
-    "unsloth/Devstral-Small-2505-unsloth-bnb-4bit" : (
+    "unsloth/Devstral-Small-2505-unsloth-bnb-4bit": (
         "unsloth/Devstral-Small-2505",
         "mistralai/Devstral-Small-2505",
         "unsloth/Devstral-Small-2505-bnb-4bit",
     ),
-    "unsloth/DeepSeek-R1-0528-Qwen3-8B-unsloth-bnb-4bit" : (
+    "unsloth/DeepSeek-R1-0528-Qwen3-8B-unsloth-bnb-4bit": (
         "unsloth/DeepSeek-R1-0528-Qwen3-8B",
         "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
         "unsloth/DeepSeek-R1-0528-Qwen3-8B-bnb-4bit",
     ),
-    "unsloth/Magistral-Small-2506-unsloth-bnb-4bit" : (
+    "unsloth/Magistral-Small-2506-unsloth-bnb-4bit": (
         "unsloth/Magistral-Small-2506",
         "mistralai/Magistral-Small-2506",
         "unsloth/Magistral-Small-2506-bnb-4bit",
     ),
-    "unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit" : (
+    "unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit": (
         "unsloth/Mistral-Small-3.2-24B-Instruct-2506",
         "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
         "unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit",
     ),
-    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit": (
         "unsloth/gemma-3n-E4B-it",
         "google/gemma-3n-E4B-it",
         "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
     ),
-    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit": (
         "unsloth/gemma-3n-E2B-it",
         "google/gemma-3n-E2B-it",
         "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
     ),
-    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit": (
         "unsloth/gemma-3n-E4B",
         "google/gemma-3n-E4B",
         "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
     ),
-    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit": (
         "unsloth/gemma-3n-E2B",
         "google/gemma-3n-E2B",
         "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",
     ),
-    "unsloth/Devstral-Small-2507-unsloth-bnb-4bit" : (
+    "unsloth/Devstral-Small-2507-unsloth-bnb-4bit": (
         "unsloth/Devstral-Small-2507",
         "mistralai/Devstral-Small-2507",
         "unsloth/Devstral-Small-2507-bnb-4bit",
     ),
-    "unsloth/Qwen3-30B-A3B-Thinking-2507" : (
+    "unsloth/Qwen3-30B-A3B-Thinking-2507": (
         "unsloth/Qwen3-30B-A3B-Thinking-2507",
         "Qwen/Qwen3-30B-A3B-Thinking-2507",
     ),
-    "unsloth/Qwen3-30B-A3B-Instruct-2507" : (
+    "unsloth/Qwen3-30B-A3B-Instruct-2507": (
         "unsloth/Qwen3-30B-A3B-Instruct-2507",
         "Qwen/Qwen3-30B-A3B-Instruct-2507",
     ),
-    "unsloth/Qwen3-Coder-30B-A3B-Instruct" : (
+    "unsloth/Qwen3-Coder-30B-A3B-Instruct": (
         "unsloth/Qwen3-Coder-30B-A3B-Instruct",
         "Qwen/Qwen3-Coder-30B-A3B-Instruct",
     ),
-    "unsloth/gpt-oss-20b-unsloth-bnb-4bit" : (
+    "unsloth/gpt-oss-20b-unsloth-bnb-4bit": (
         "unsloth/gpt-oss-20b",
         "openai/gpt-oss-20b",
         "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
     ),
-    "unsloth/gpt-oss-120b-unsloth-bnb-4bit" : (
+    "unsloth/gpt-oss-120b-unsloth-bnb-4bit": (
         "unsloth/gpt-oss-120b",
         "openai/gpt-oss-120b",
         "unsloth/gpt-oss-120b-unsloth-bnb-4bit",
     ),
-    "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit": (
         "unsloth/Qwen3-4B-Instruct-2507",
         "Qwen/Qwen3-4B-Instruct-2507",
         "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit",
     ),
-    "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit": (
         "unsloth/Qwen3-4B-Thinking-2507",
         "Qwen/Qwen3-4B-Thinking-2507",
         "unsloth/Qwen3-4B-Thinking-2507-bnb-4bit",
     ),
-    "unsloth/gemma-3-270m-it-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-270m-it-unsloth-bnb-4bit": (
         "unsloth/gemma-3-270m-it",
         "google/gemma-3-270m-it",
         "unsloth/gemma-3-270m-it-bnb-4bit",
     ),
-    "unsloth/gemma-3-270m-unsloth-bnb-4bit" : (
+    "unsloth/gemma-3-270m-unsloth-bnb-4bit": (
         "unsloth/gemma-3-270m",
         "google/gemma-3-270m",
         "unsloth/gemma-3-270m-bnb-4bit",
     ),
-    "unsloth/Magistral-Small-2507-unsloth-bnb-4bit" : (
+    "unsloth/Magistral-Small-2507-unsloth-bnb-4bit": (
         "unsloth/Magistral-Small-2507",
         "mistralai/Magistral-Small-2507",
         "unsloth/Magistral-Small-2507-bnb-4bit",
     ),
-    "unsloth/Magistral-Small-2509-unsloth-bnb-4bit" : (
+    "unsloth/Magistral-Small-2509-unsloth-bnb-4bit": (
         "unsloth/Magistral-Small-2509",
         "mistralai/Magistral-Small-2509",
         "unsloth/Magistral-Small-2509-bnb-4bit",
     ),
-    "unsloth/Apertus-70B-Instruct-2509-unsloth-bnb-4bit" : (
+    "unsloth/Apertus-70B-Instruct-2509-unsloth-bnb-4bit": (
         "unsloth/Apertus-70B-Instruct-2509",
         "swiss-ai/Apertus-70B-2509",
         "unsloth/Apertus-70B-Instruct-2509-unsloth-bnb-4bit",
     ),
-    "unsloth/Apertus-8B-Instruct-2509-unsloth-bnb-4bit" : (
+    "unsloth/Apertus-8B-Instruct-2509-unsloth-bnb-4bit": (
         "unsloth/Apertus-8B-Instruct-2509",
         "swiss-ai/Apertus-8B-2509",
         "unsloth/Apertus-8B-Instruct-2509-unsloth-bnb-4bit",
     ),
-    "unsloth/granite-4.0-micro-unsloth-bnb-4bit" : (
+    "unsloth/granite-4.0-micro-unsloth-bnb-4bit": (
         "unsloth/granite-4.0-micro",
         "ibm-granite/granite-4.0-micro",
         "unsloth/granite-4.0-micro-bnb-4bit",
     ),
-    "unsloth/granite-4.0-h-micro-unsloth-bnb-4bit" : (
+    "unsloth/granite-4.0-h-micro-unsloth-bnb-4bit": (
         "unsloth/granite-4.0-h-micro",
         "ibm-granite/granite-4.0-h-micro",
         "unsloth/granite-4.0-h-micro-bnb-4bit",
     ),
-    "unsloth/granite-4.0-micro-base-unsloth-bnb-4bit" : (
+    "unsloth/granite-4.0-micro-base-unsloth-bnb-4bit": (
         "unsloth/granite-4.0-micro-base",
         "ibm-granite/granite-4.0-micro-base",
         "unsloth/granite-4.0-micro-base-bnb-4bit",
     ),
-    "unsloth/granite-4.0-h-micro-base-unsloth-bnb-4bit" : (
+    "unsloth/granite-4.0-h-micro-base-unsloth-bnb-4bit": (
         "unsloth/granite-4.0-h-micro-base",
         "ibm-granite/granite-4.0-h-micro-base",
         "unsloth/granite-4.0-h-micro-base-bnb-4bit",
     ),
-    "unsloth/granite-4.0-h-tiny" : (
+    "unsloth/granite-4.0-h-tiny": (
         "unsloth/granite-4.0-h-tiny",
         "ibm-granite/granite-4.0-h-tiny",
     ),
-    "unsloth/granite-4.0-h-small" : (
+    "unsloth/granite-4.0-h-small": (
         "unsloth/granite-4.0-h-small",
         "ibm-granite/granite-4.0-h-small",
     ),
-    "unsloth/granite-4.0-h-tiny-base" : (
+    "unsloth/granite-4.0-h-tiny-base": (
         "unsloth/granite-4.0-h-tiny-base",
         "ibm-granite/granite-4.0-h-tiny-base",
     ),
-    "unsloth/granite-4.0-h-small-base" : (
+    "unsloth/granite-4.0-h-small-base": (
         "unsloth/granite-4.0-h-small-base",
         "ibm-granite/granite-4.0-h-small-base",
     ),
-    "unsloth/Qwen3-VL-4B-Thinking-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-4B-Thinking-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-4B-Thinking",
         "Qwen/Qwen3-VL-4B-Thinking",
         "unsloth/Qwen3-VL-4B-Thinking-bnb-4bit",
     ),
-    "unsloth/Qwen3-VL-8B-Thinking-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-8B-Thinking-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-8B-Thinking",
         "Qwen/Qwen3-VL-8B-Thinking",
         "unsloth/Qwen3-VL-8B-Thinking-bnb-4bit",
     ),
-    "unsloth/Qwen3-VL-4B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-4B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-4B-Instruct",
         "Qwen/Qwen3-VL-4B-Instruct",
         "unsloth/Qwen3-VL-4B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-8B-Instruct",
         "Qwen/Qwen3-VL-8B-Instruct",
         "unsloth/Qwen3-VL-8B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen3-VL-2B-Thinking-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-2B-Thinking-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-2B-Thinking",
         "Qwen/Qwen3-VL-2B-Thinking",
         "unsloth/Qwen3-VL-2B-Thinking-bnb-4bit",
     ),
-    "unsloth/Qwen3-VL-2B-Thinking-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-2B-Thinking-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-2B-Thinking",
         "Qwen/Qwen3-VL-2B-Thinking",
         "unsloth/Qwen3-VL-2B-Thinking-bnb-4bit",
     ),
-    "unsloth/Qwen3-VL-32B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-32B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-32B-Instruct",
         "Qwen/Qwen3-VL-32B-Instruct",
         "unsloth/Qwen3-VL-32B-Instruct-bnb-4bit",
     ),
-    "unsloth/Qwen3-VL-32B-Instruct-unsloth-bnb-4bit" : (
+    "unsloth/Qwen3-VL-32B-Instruct-unsloth-bnb-4bit": (
         "unsloth/Qwen3-VL-32B-Instruct",
         "Qwen/Qwen3-VL-32B-Instruct",
         "unsloth/Qwen3-VL-32B-Instruct-bnb-4bit",
     ),
 }
 
-INT_TO_FLOAT_MAPPER  = {}
-FLOAT_TO_INT_MAPPER  = {}
+INT_TO_FLOAT_MAPPER = {}
+FLOAT_TO_INT_MAPPER = {}
 MAP_TO_UNSLOTH_16bit = {}
 
 for key, values in __INT_TO_FLOAT_MAPPER.items():
@@ -1063,14 +1042,12 @@
 
     for value in values:
         FLOAT_TO_INT_MAPPER[value] = key
-    pass
 
     # Map to Unsloth version for 16bit versions
     if len(values) == 2:
         if values[0].startswith("unsloth"):
             MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
             MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
-        pass
     elif len(values) == 3:
         # Dynamic Unsloth quantization
         if values[0].startswith("unsloth"):
@@ -1078,8 +1055,6 @@
             MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
             MAP_TO_UNSLOTH_16bit[values[2]] = values[0]
             MAP_TO_UNSLOTH_16bit[values[2].lower()] = values[0]
-        pass
-    pass
 
     # Get lowercased
     lowered_key = key.lower()
@@ -1087,5 +1062,3 @@
 
     for value in values:
         FLOAT_TO_INT_MAPPER[value.lower()] = lowered_key
-    pass
-pass
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index 4db712b07..249501f38 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -27,6 +27,7 @@
     MistralModel,
     MistralForCausalLM,
 )
+
 # For Pytorch 2.1.1
 try:
     from transformers.models.mistral.modeling_mistral import (
@@ -34,26 +35,25 @@
         MistralFlashAttention2,
     )
 except:
-    MistralSdpaAttention   = MistralAttention
+    MistralSdpaAttention = MistralAttention
     MistralFlashAttention2 = MistralAttention
-pass
 from unsloth_zoo.utils import Version, _get_dtype
 
 
 def MistralAttention_fast_forward(
     self,
-    hidden_states:       torch.Tensor,
-    causal_mask:         Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:      Optional[torch.Tensor] = None,
-    position_ids:        Optional[torch.LongTensor] = None,
-    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:   bool = False,
-    use_cache:           bool = False,
-    padding_mask:        Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -63,18 +63,17 @@ def MistralAttention_fast_forward(
         del self.temp_KV
         del self.RH_Q
         del self.attention
-    pass
 
     bsz, q_len, _ = hidden_states.size()
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
+    head_dim = self.head_dim
+    assert n_kv_heads * n_groups == n_heads
 
     Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
+    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
     K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
@@ -90,16 +89,14 @@ def MistralAttention_fast_forward(
         Q, K = fast_rope_embedding(Q, K, cos, sin)
     else:
         Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
-    pass
 
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
         V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
     past_key_value = (K, V) if use_cache else None
 
     # Attention module
-    if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
+    if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
         # Xformers memory efficient attention
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
@@ -110,8 +107,8 @@ def MistralAttention_fast_forward(
         has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
 
         # Group query attention
-        K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
-        V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+        K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+        V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
         K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         if hidden_states.requires_grad:
@@ -122,7 +119,6 @@ def MistralAttention_fast_forward(
                 Q = Q.view(1, Q_M, n_heads, head_dim)
                 K = K.view(1, K_M, n_heads, head_dim)
                 V = V.view(1, V_M, n_heads, head_dim)
-            pass
         else:
             # Xformers does support the forward pass though
             Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
@@ -131,8 +127,6 @@ def MistralAttention_fast_forward(
                 Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
                 K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
                 V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
-            pass
-        pass
 
         A = xformers_attention(Q, K, V, attn_bias = causal_mask)
         A = A.view(bsz, q_len, n_heads, head_dim)
@@ -158,16 +152,16 @@ def MistralAttention_fast_forward(
         Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
         # Needs (batch_size, n_heads, seq_len, head_dim)
         # is_casual and attention_mask must not be both set!
-        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+        A = scaled_dot_product_attention(
+            Q, K, V, attn_mask = attention_mask, is_causal = False
+        )
         # Go back to (batch_size, seq_len, n_heads, head_dim)
         A = A.transpose(1, 2).contiguous()
-    pass
 
-    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
+    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
     attn_output = self.apply_o(self, attn_output)
     attn_weights = None
     return attn_output, attn_weights, past_key_value
-pass
 
 
 def MistralForCausalLM_fast_forward(
@@ -185,44 +179,60 @@ def MistralForCausalLM_fast_forward(
     return_dict: Optional[bool] = None,
     num_logits_to_keep: Optional[int] = 0,
     logits_to_keep: Optional[int] = 0,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
-
     if causal_mask is None and past_key_values is None:
         bsz, q_len = input_ids.shape
         sliding_window = getattr(self.config, "sliding_window", None)
 
         if HAS_XFORMERS:
             # Always create causal mask for xformers
-            if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
+            if (
+                sliding_window is None
+                or sliding_window == "null"
+                or sliding_window <= 0
+            ):
                 causal_mask = xformers.attn_bias.LowerTriangularMask()
             elif q_len <= sliding_window:
                 causal_mask = xformers.attn_bias.LowerTriangularMask()
             else:
-                causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
-                    .from_seqlens([q_len]*bsz)\
-                    .make_local_attention(window_size = sliding_window)
+                causal_mask = xformers.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                    [q_len] * bsz
+                ).make_local_attention(window_size = sliding_window)
 
             # If attention_mask exists, it will be handled in the attention forward
 
         else:
             # Not using xformers - need to create attention masks
-            if sliding_window is None or sliding_window == "null" or sliding_window <= 0 or q_len <= sliding_window:
+            if (
+                sliding_window is None
+                or sliding_window == "null"
+                or sliding_window <= 0
+                or q_len <= sliding_window
+            ):
                 # Fully causal mask
-                causal_mask_values = torch.triu(torch.full((q_len, q_len), -torch.inf, device=input_ids.device), diagonal=1)
+                causal_mask_values = torch.triu(
+                    torch.full((q_len, q_len), -torch.inf, device = input_ids.device),
+                    diagonal = 1,
+                )
             else:
                 # Sliding window attention
-                q_indices = torch.arange(q_len, device=input_ids.device).view(-1, 1)
-                k_indices = torch.arange(q_len, device=input_ids.device).view(1, -1)
+                q_indices = torch.arange(q_len, device = input_ids.device).view(-1, 1)
+                k_indices = torch.arange(q_len, device = input_ids.device).view(1, -1)
 
                 causal_bool_mask = k_indices <= q_indices
                 window_bool_mask = (q_indices - k_indices) < sliding_window
 
-                causal_mask_values = torch.where(causal_bool_mask & window_bool_mask, 0.0, -torch.inf)
+                causal_mask_values = torch.where(
+                    causal_bool_mask & window_bool_mask, 0.0, -torch.inf
+                )
 
             # Combine with existing attention_mask if present
             if attention_mask is None:
-                attention_mask = causal_mask_values[None, None, :, :].expand(bsz, 1, q_len, q_len)
+                attention_mask = causal_mask_values[None, None, :, :].expand(
+                    bsz, 1, q_len, q_len
+                )
             else:
                 # attention_mask should be [bsz, 1, q_len, q_len] or broadcastable
                 # Add causal mask to existing attention mask
@@ -232,13 +242,23 @@ def MistralForCausalLM_fast_forward(
                     attention_mask = attention_mask.expand(bsz, 1, q_len, q_len)
                 attention_mask = attention_mask + causal_mask_values[None, None, :, :]
 
-            attention_mask = attention_mask.to(dtype=_get_dtype(dtype_from_config(self.config)))
+            attention_mask = attention_mask.to(
+                dtype = _get_dtype(dtype_from_config(self.config))
+            )
 
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
     output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
     )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     self.model._has_no_labels = labels is None
@@ -264,7 +284,6 @@ def MistralForCausalLM_fast_forward(
             output_hidden_states = output_hidden_states,
             return_dict = return_dict,
         )
-    pass
 
     hidden_states = outputs[0]
 
@@ -274,7 +293,8 @@ def MistralForCausalLM_fast_forward(
 
     # Move items to same device as lm_head
     hidden_states = hidden_states.to(lm_head_device)
-    if labels is not None: labels = labels.to(lm_head_device)
+    if labels is not None:
+        labels = labels.to(lm_head_device)
 
     # If we are in GRPO mode, return raw hidden states
     if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
@@ -288,13 +308,14 @@ def MistralForCausalLM_fast_forward(
             hidden_states = outputs.hidden_states,
             attentions = outputs.attentions,
         )
-    pass
 
     if bsz == 1 and q_len == 1:
         logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
         logits = logits.unsqueeze(0).unsqueeze(0)
     elif num_logits_to_keep != 0:
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
+        logits = self.lm_head(
+            hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)
+        )
     else:
         RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
         # < 1024 Normal Unsloth uses less VRAM!
@@ -303,7 +324,9 @@ def MistralForCausalLM_fast_forward(
             RETURN_LOGITS = False
 
         if not RETURN_LOGITS and labels is not None:
-            n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
+            n_items = kwargs.get("num_items_in_batch", None) or kwargs.get(
+                "n_items", None
+            )
             logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
 
             # loss = fused_linear_cross_entropy(
@@ -314,17 +337,17 @@ def MistralForCausalLM_fast_forward(
             #     logit_softcapping = logit_softcapping,
             # )
             loss = unsloth_fused_ce_loss(
-                trainer              = None,
-                hidden_states        = hidden_states,
-                lm_head_weight       = lm_head,
-                lm_head_bias         = None,
-                labels               = labels,
-                mask                 = None,
-                n_items              = n_items,
-                scaling              = getattr(self, "accelerator_scaler", None),
-                target_gb            = None,
-                torch_compile        = True,
-                logit_softcapping    = logit_softcapping,
+                trainer = None,
+                hidden_states = hidden_states,
+                lm_head_weight = lm_head,
+                lm_head_bias = None,
+                labels = labels,
+                mask = None,
+                n_items = n_items,
+                scaling = getattr(self, "accelerator_scaler", None),
+                target_gb = None,
+                torch_compile = True,
+                logit_softcapping = logit_softcapping,
             )
             if not return_dict:
                 output = (logits,) + outputs[1:]
@@ -338,9 +361,7 @@ def MistralForCausalLM_fast_forward(
                 attentions = outputs.attentions,
             )
             return output
-        pass
         logits = self.lm_head(hidden_states.to(lm_head.dtype))
-    pass
     logits = logits.to(_get_dtype(dtype_from_config(self.config)))
 
     loss = None
@@ -355,11 +376,11 @@ def MistralForCausalLM_fast_forward(
         shift_labels[..., :-1] = labels[..., 1:]
         shift_labels[..., -1] = -100
         loss = fast_cross_entropy_loss(
-            logits  = shift_logits,
-            labels  = shift_labels,
-            n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
+            logits = shift_logits,
+            labels = shift_labels,
+            n_items = kwargs.get("num_items_in_batch", None)
+            or kwargs.get("n_items", None),
         )
-    pass
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -372,7 +393,6 @@ def MistralForCausalLM_fast_forward(
         hidden_states = outputs.hidden_states,
         attentions = outputs.attentions,
     )
-pass
 
 
 # Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
@@ -390,33 +410,30 @@ def patch_mistral_nemo_attention(function):
         "self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)",
     )
     return function
-pass
 
 
 class FastMistralModel(FastLlamaModel):
-
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "mistral",
-            rope_module        = LlamaRotaryEmbedding,
+            model_name = "mistral",
+            rope_module = LlamaRotaryEmbedding,
             scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
-            attention_module   = MistralAttention,
+            attention_module = MistralAttention,
         )
         # Just for Mistral Nemo models!
         if function is not None and init_name is not None:
             function = patch_mistral_nemo_attention(function)
             # if True:#init_name is not None:
             exec(function, globals())
-            MistralAttention.__init__  = eval(init_name)
-        pass
-        MistralAttention      .forward = MistralAttention_fast_forward
-        MistralSdpaAttention  .forward = MistralAttention_fast_forward
+            MistralAttention.__init__ = eval(init_name)
+        MistralAttention.forward = MistralAttention_fast_forward
+        MistralSdpaAttention.forward = MistralAttention_fast_forward
         MistralFlashAttention2.forward = MistralAttention_fast_forward
-        MistralDecoderLayer   .forward = LlamaDecoderLayer_fast_forward
-        MistralModel          .forward = LlamaModel_fast_forward
-        MistralForCausalLM    .forward = MistralForCausalLM_fast_forward
-        PeftModelForCausalLM  .forward = PeftModel_fast_forward
+        MistralDecoderLayer.forward = LlamaDecoderLayer_fast_forward
+        MistralModel.forward = LlamaModel_fast_forward
+        MistralForCausalLM.forward = MistralForCausalLM_fast_forward
+        PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(MistralForCausalLM)
 
         # Solves https://github.com/unslothai/unsloth/issues/168
@@ -425,39 +442,38 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.mistral.modeling_mistral
-        transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = LlamaRotaryEmbedding
-        return
-    pass
 
+        transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = (
+            LlamaRotaryEmbedding
+        )
+        return
 
     @staticmethod
     def from_pretrained(
-        model_name        = "unsloth/mistral-7b-bnb-4bit",
-        max_seq_length    = None,
-        dtype             = None,
-        load_in_4bit      = True,
-        token             = None,
-        device_map        = "sequential",
-        rope_scaling      = None, # Mistral does not support RoPE scaling
-        fix_tokenizer     = True,
-        model_patcher     = None,
-        tokenizer_name    = None,
+        model_name = "unsloth/mistral-7b-bnb-4bit",
+        max_seq_length = None,
+        dtype = None,
+        load_in_4bit = True,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,  # Mistral does not support RoPE scaling
+        fix_tokenizer = True,
+        model_patcher = None,
+        tokenizer_name = None,
         trust_remote_code = False,
         **kwargs,
     ):
         return FastLlamaModel.from_pretrained(
-            model_name        = model_name,
-            max_seq_length    = max_seq_length,
-            dtype             = dtype,
-            load_in_4bit      = load_in_4bit,
-            token             = token,
-            device_map        = device_map,
-            rope_scaling      = rope_scaling,
-            fix_tokenizer     = fix_tokenizer,
-            model_patcher     = FastMistralModel,
-            tokenizer_name    = tokenizer_name,
+            model_name = model_name,
+            max_seq_length = max_seq_length,
+            dtype = dtype,
+            load_in_4bit = load_in_4bit,
+            token = token,
+            device_map = device_map,
+            rope_scaling = rope_scaling,
+            fix_tokenizer = fix_tokenizer,
+            model_patcher = FastMistralModel,
+            tokenizer_name = tokenizer_name,
             trust_remote_code = trust_remote_code,
             **kwargs,
         )
-    pass
-pass
diff --git a/unsloth/models/qwen2.py b/unsloth/models/qwen2.py
index b07391365..64ed707fe 100644
--- a/unsloth/models/qwen2.py
+++ b/unsloth/models/qwen2.py
@@ -23,6 +23,7 @@
     Qwen2Model,
     Qwen2ForCausalLM,
 )
+
 # For Pytorch 2.1.1
 try:
     from transformers.models.qwen2.modeling_qwen2 import (
@@ -30,31 +31,30 @@
         Qwen2FlashAttention2,
     )
 except:
-    Qwen2SdpaAttention   = Qwen2Attention
+    Qwen2SdpaAttention = Qwen2Attention
     Qwen2FlashAttention2 = Qwen2Attention
-pass
 
 
 class FastQwen2Model(FastLlamaModel):
-
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "qwen2",
-            rope_module        = LlamaRotaryEmbedding,
+            model_name = "qwen2",
+            rope_module = LlamaRotaryEmbedding,
             scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
-            attention_module   = Qwen2Attention,
+            attention_module = Qwen2Attention,
         )
         if init_name is not None:
             exec(function, globals())
-            Qwen2Attention.__init__  = eval(init_name)
-        pass
-        Qwen2Attention      .forward = LlamaAttention_fast_forward
-        Qwen2SdpaAttention  .forward = LlamaAttention_fast_forward
+            Qwen2Attention.__init__ = eval(init_name)
+        Qwen2Attention.forward = LlamaAttention_fast_forward
+        Qwen2SdpaAttention.forward = LlamaAttention_fast_forward
         Qwen2FlashAttention2.forward = LlamaAttention_fast_forward
-        Qwen2DecoderLayer   .forward = LlamaDecoderLayer_fast_forward
-        Qwen2Model          .forward = LlamaModel_fast_forward
-        Qwen2ForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
+        Qwen2DecoderLayer.forward = LlamaDecoderLayer_fast_forward
+        Qwen2Model.forward = LlamaModel_fast_forward
+        Qwen2ForCausalLM.forward = CausalLM_fast_forward(
+            LlamaModel_fast_forward_inference
+        )
         PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(Qwen2ForCausalLM)
 
@@ -64,39 +64,38 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.qwen2.modeling_qwen2
-        transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = LlamaRotaryEmbedding
-        return
-    pass
 
+        transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = (
+            LlamaRotaryEmbedding
+        )
+        return
 
     @staticmethod
     def from_pretrained(
-        model_name        = "Qwen/Qwen2-7B",
-        max_seq_length    = 4096,
-        dtype             = None,
-        load_in_4bit      = True,
-        token             = None,
-        device_map        = "sequential",
-        rope_scaling      = None, # Qwen2 does not support RoPE scaling
-        fix_tokenizer     = True,
-        model_patcher     = None,
-        tokenizer_name    = None,
+        model_name = "Qwen/Qwen2-7B",
+        max_seq_length = 4096,
+        dtype = None,
+        load_in_4bit = True,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,  # Qwen2 does not support RoPE scaling
+        fix_tokenizer = True,
+        model_patcher = None,
+        tokenizer_name = None,
         trust_remote_code = False,
         **kwargs,
     ):
         return FastLlamaModel.from_pretrained(
-            model_name        = model_name,
-            max_seq_length    = max_seq_length,
-            dtype             = dtype,
-            load_in_4bit      = load_in_4bit,
-            token             = token,
-            device_map        = device_map,
-            rope_scaling      = rope_scaling,
-            fix_tokenizer     = fix_tokenizer,
-            model_patcher     = FastQwen2Model,
-            tokenizer_name    = tokenizer_name,
+            model_name = model_name,
+            max_seq_length = max_seq_length,
+            dtype = dtype,
+            load_in_4bit = load_in_4bit,
+            token = token,
+            device_map = device_map,
+            rope_scaling = rope_scaling,
+            fix_tokenizer = fix_tokenizer,
+            model_patcher = FastQwen2Model,
+            tokenizer_name = tokenizer_name,
             trust_remote_code = trust_remote_code,
             **kwargs,
         )
-    pass
-pass
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py
index 95d723200..3d905fe06 100644
--- a/unsloth/models/qwen3.py
+++ b/unsloth/models/qwen3.py
@@ -21,6 +21,7 @@
     LlamaLinearScalingRotaryEmbedding,
     _LlamaModel_fast_forward_inference,
 )
+
 try:
     from transformers.models.qwen3.modeling_qwen3 import (
         Qwen3Attention,
@@ -30,17 +31,19 @@
     )
 except:
     transformers_version = Version(transformers_version)
-    if not transformers_version >= Version("4.50.3"): #TODO: Update when transformers is updated
+    if not transformers_version >= Version(
+        "4.50.3"
+    ):  # TODO: Update when transformers is updated
         raise ImportError(
-            f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3 and Qwen3Moe.\n"\
-            f"The minimum required version is 4.50.3.\n"\
-            f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
-            f"to obtain the latest transformers build, then restart this session."\
+            f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3 and Qwen3Moe.\n"
+            f"The minimum required version is 4.50.3.\n"
+            f'Try `pip install --upgrade "transformers>=4.50.3"`\n'
+            f"to obtain the latest transformers build, then restart this session."
         )
-    pass
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
+
 # For Pytorch 2.1.1
 try:
     from transformers.models.qwen3.modeling_qwen3 import (
@@ -48,25 +51,24 @@
         Qwen3FlashAttention2,
     )
 except:
-    Qwen3SdpaAttention   = Qwen3Attention
+    Qwen3SdpaAttention = Qwen3Attention
     Qwen3FlashAttention2 = Qwen3Attention
-pass
 
 
 def Qwen3Attention_fast_forward(
     self,
-    hidden_states:       torch.Tensor,
-    causal_mask:         Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:      Optional[torch.Tensor] = None,
-    position_ids:        Optional[torch.LongTensor] = None,
-    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:   bool = False,
-    use_cache:           bool = False,
-    padding_mask:        Optional[torch.LongTensor] = None,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    *args,
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -76,22 +78,25 @@ def Qwen3Attention_fast_forward(
         del self.temp_KV
         del self.RH_Q
         del self.attention
-    pass
 
     bsz, q_len, _ = hidden_states.size()
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
-    assert(n_kv_heads * n_groups == n_heads)
+    head_dim = self.head_dim
+    assert n_kv_heads * n_groups == n_heads
 
     Q, K, V = self.apply_qkv(self, hidden_states)
-    Q = Q.view(bsz, q_len, n_heads,    head_dim)#.transpose(1, 2) # we will transpose after normalisation
-    K = K.view(bsz, q_len, n_kv_heads, head_dim)#.transpose(1, 2) # we will transpose after normalisation
+    Q = Q.view(
+        bsz, q_len, n_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
+    K = K.view(
+        bsz, q_len, n_kv_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
     V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
-    #Qwen3 has QKNorm. This seems to be the only difference from Qwen2.
+    # Qwen3 has QKNorm. This seems to be the only difference from Qwen2.
     # Note that using fast_layernorm_compiled causes issues as the dimensions don't match up.
     # I tried to add a compiled version of the new norm but the numbers don't match up with Transformers
     # TODO: Check on the differences here.
@@ -123,11 +128,10 @@ def Qwen3Attention_fast_forward(
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
         V = torch.cat([past_key_value[1], V], dim = 2)
-    pass
     past_key_value = (K, V) if use_cache else None
 
     # Attention module
-    if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
+    if not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None:
         # Xformers memory efficient attention
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
@@ -138,8 +142,8 @@ def Qwen3Attention_fast_forward(
         has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
 
         # Group query attention
-        K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
-        V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
+        K = K.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+        V = V.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
         K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
         if hidden_states.requires_grad:
@@ -150,7 +154,6 @@ def Qwen3Attention_fast_forward(
                 Q = Q.view(1, Q_M, n_heads, head_dim)
                 K = K.view(1, K_M, n_heads, head_dim)
                 V = V.view(1, V_M, n_heads, head_dim)
-            pass
         else:
             # Xformers does support the forward pass though
             Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
@@ -159,8 +162,6 @@ def Qwen3Attention_fast_forward(
                 Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
                 K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
                 V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
-            pass
-        pass
 
         A = xformers_attention(Q, K, V, attn_bias = causal_mask)
         A = A.view(bsz, q_len, n_heads, head_dim)
@@ -193,67 +194,70 @@ def Qwen3Attention_fast_forward(
         else:
             is_causal = False
 
-        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = is_causal)
+        A = scaled_dot_product_attention(
+            Q, K, V, attn_mask = attention_mask, is_causal = is_causal
+        )
         # Go back to (batch_size, seq_len, n_heads, head_dim)
         A = A.transpose(1, 2).contiguous()
-    pass
 
-    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
+    attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
     attn_output = self.apply_o(self, attn_output)
     attn_weights = None
     return attn_output, attn_weights, past_key_value
-pass
+
 
 torch_matmul = torch.matmul
+
+
 def Qwen3Attention_fast_forward_inference(
     self,
-    hidden_states:  torch.Tensor,
+    hidden_states: torch.Tensor,
     past_key_value: Optional[Tuple[torch.Tensor]],
     position_ids,
     do_prefill = False,
     attention_mask = None,
 ):
     """
-        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
-        Fast inference using KV cache.
-        QK^T can be computed in 4 chunks
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
+    Fast inference using KV cache.
+    QK^T can be computed in 4 chunks
 
-        [Q, q] @ [K, k].T where q, k are the new tokens.
-        [QK^T, Qk^T]
-        [qK^T, qk^T]
+    [Q, q] @ [K, k].T where q, k are the new tokens.
+    [QK^T, Qk^T]
+    [qK^T, qk^T]
 
-        Since the attention mask wipes Qk^T, we just get
-        [QK^T,    0]
-        [qK^T, qk^T]
+    Since the attention mask wipes Qk^T, we just get
+    [QK^T,    0]
+    [qK^T, qk^T]
 
-        Since softmax is row-wise, we get
-        softmax([QK^T,    0])
-        softmax([qK^T, qk^T])
+    Since softmax is row-wise, we get
+    softmax([QK^T,    0])
+    softmax([qK^T, qk^T])
 
-        We then multiply by   [V]
-                              [v]
-        softmax([QK^T,    0]) [softmax(QK^T)V] *
-        softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
+    We then multiply by   [V]
+                          [v]
+    softmax([QK^T,    0]) [softmax(QK^T)V] *
+    softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
 
-        But notice * [softmax(QK^T)V] is just the last attention.
-        We just need to compute the last final row.
+    But notice * [softmax(QK^T)V] is just the last attention.
+    We just need to compute the last final row.
 
-        This means we can pass in a row of Q, but we need to
-        remember K and V, which are called the KV cache.
+    This means we can pass in a row of Q, but we need to
+    remember K and V, which are called the KV cache.
     """
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
     dtype = Xn.dtype
 
-    n_heads    = self.config.num_attention_heads
-    n_groups   = self.num_key_value_groups
+    n_heads = self.config.num_attention_heads
+    n_groups = self.num_key_value_groups
     n_kv_heads = self.config.num_key_value_heads
-    head_dim   = self.head_dim
+    head_dim = self.head_dim
     # assert(n_kv_heads * n_groups == n_heads)
 
     hidden_size = self.config.hidden_size
-    attention_size = n_heads*head_dim
+    attention_size = n_heads * head_dim
     seq_len = K1.shape[-2]
     kv_seq_len = seq_len + 1
 
@@ -261,37 +265,59 @@ def Qwen3Attention_fast_forward_inference(
     # if not hasattr(self, "paged_attention"):
     device = hidden_states.device
     if do_prefill:
-        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
+        self.paged_attention = torch.empty(
+            (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
+            dtype = dtype,
+            device = device,
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
         self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
         self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
-        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
-        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
+        self.temp_QA = torch.empty(
+            (2, bsz, 1, attention_size), dtype = dtype, device = device
+        )
+        self.temp_KV = torch.empty(
+            (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
+        )
         self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
 
         # Mistral Nemo 12b has weird dimensions
         if attention_size != hidden_size:
             self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
         else:
-            self.temp_O = self.temp_QA[1][:,:,:hidden_size]
-        pass
+            self.temp_O = self.temp_QA[1][:, :, :hidden_size]
 
-        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
+        self.attention = torch.empty(
+            (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
+        )
         self.scalar = 1.0 / math_sqrt(self.head_dim)
         self.half_head_dim = head_dim // 2
     elif kv_seq_len >= self.paged_attention.shape[0]:
-        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
-        self.paged_attention_K = self.paged_attention[:,0]
-        self.paged_attention_V = self.paged_attention[:,1]
-        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
-    pass
+        self.paged_attention.resize_(
+            (
+                self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
+                2,
+                bsz,
+                n_kv_heads,
+                head_dim,
+            )
+        )
+        self.paged_attention_K = self.paged_attention[:, 0]
+        self.paged_attention_V = self.paged_attention[:, 1]
+        self.attention.resize_(
+            (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
+        )
 
     Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
     Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
     Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
-    Qn = Qn.view(bsz, 1, n_heads,    head_dim)#.transpose(1, 2) # we will transpose after normalisation
-    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim)#.transpose(1, 2) # we will transpose after normalisation
+    Qn = Qn.view(
+        bsz, 1, n_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
+    Kn = Kn.view(
+        bsz, 1, n_kv_heads, head_dim
+    )  # .transpose(1, 2) # we will transpose after normalisation
     Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
 
     Qn = fast_rms_layernorm_inference(self.q_norm, Qn)
@@ -312,16 +338,18 @@ def Qwen3Attention_fast_forward_inference(
     h = self.half_head_dim
 
     RH_Q = self.RH_Q
-    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
-    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
-    RH_Q[:,:,:,:h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+    RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
+    RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
+    RH_Q[:, :, :, :h].neg_()  # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
     Qn *= cos
     Qn.addcmul_(RH_Q, sin)
 
-    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
-    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
-    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
-    RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+    RH_K = RH_Q[
+        :, :n_kv_heads, :, :
+    ]  # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+    RH_K[:, :, :, :h] = Kn[:, :, :, h:]
+    RH_K[:, :, :, h:] = Kn[:, :, :, :h]
+    RH_K[:, :, :, :h].neg_()  # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
 
@@ -338,11 +366,10 @@ def Qwen3Attention_fast_forward_inference(
     if sliding_window is not None and kv_seq_len > sliding_window:
         # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
         slicing_tokens = 1 - sliding_window
-        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
-        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
+        Knn = Kn[:, :, slicing_tokens:, :]  # .contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]  # .contiguous()
     else:
         Knn, Vnn = Kn, Vn
-    pass
 
     # when qlen==vlen and attn_mask is None, we should use causal attention
     Q_len = Qn.shape[-2]
@@ -355,55 +382,70 @@ def Qwen3Attention_fast_forward_inference(
     # Grouped query attention
     _, _, cached_len, _ = Knn.shape
     if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1:
-        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
+        Vnn = Vnn[:, :, None, :, :].expand(
+            bsz, n_kv_heads, n_groups, cached_len, head_dim
+        )
         Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
         Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-    pass
     # else:
     #     Knn, Vnn = Knn, Vnn
     # pass
 
     # Attention
     if bsz == 1:
-        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+        Qn *= self.scalar  # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
         # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
-        A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+        A = torch_matmul(
+            Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
+        )
         # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
-        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+        A[:] = torch_nn_functional_softmax(
+            A, dim = -1, dtype = torch.float32
+        )  # .to(A.dtype)
         A = torch_matmul(A, Vnn, out = Qn)
     else:
         if SDPA_HAS_GQA:
-            A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal, enable_gqa = True)
+            A = scaled_dot_product_attention(
+                Qn,
+                Knn,
+                Vnn,
+                attn_mask = attention_mask,
+                is_causal = is_causal,
+                enable_gqa = True,
+            )
         else:
-            A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal)
-    pass
+            A = scaled_dot_product_attention(
+                Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
+            )
     A = A.transpose(1, 2)
     A = A.reshape(bsz, 1, attention_size)
     A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
     return A, (Kn, Vn)
-pass
 
-class FastQwen3Model(FastLlamaModel):
 
+class FastQwen3Model(FastLlamaModel):
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "Qwen3",
-            rope_module        = LlamaRotaryEmbedding,
+            model_name = "Qwen3",
+            rope_module = LlamaRotaryEmbedding,
             scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
-            attention_module   = Qwen3Attention,
+            attention_module = Qwen3Attention,
         )
         if init_name is not None:
             exec(function, globals())
-            Qwen3Attention.__init__  = eval(init_name)
-        pass
-        Qwen3Attention      .forward = Qwen3Attention_fast_forward
-        Qwen3SdpaAttention  .forward = Qwen3Attention_fast_forward
+            Qwen3Attention.__init__ = eval(init_name)
+        Qwen3Attention.forward = Qwen3Attention_fast_forward
+        Qwen3SdpaAttention.forward = Qwen3Attention_fast_forward
         Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward
-        Qwen3DecoderLayer   .forward = LlamaDecoderLayer_fast_forward
-        Qwen3Model          .forward = LlamaModel_fast_forward
-        Qwen3ForCausalLM    .forward = CausalLM_fast_forward(_LlamaModel_fast_forward_inference(Qwen3Attention_fast_forward_inference))
+        Qwen3DecoderLayer.forward = LlamaDecoderLayer_fast_forward
+        Qwen3Model.forward = LlamaModel_fast_forward
+        Qwen3ForCausalLM.forward = CausalLM_fast_forward(
+            _LlamaModel_fast_forward_inference(Qwen3Attention_fast_forward_inference)
+        )
         PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(Qwen3ForCausalLM)
 
@@ -413,39 +455,38 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
         import transformers.models.qwen3.modeling_qwen3
-        transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = LlamaRotaryEmbedding
-        return
-    pass
 
+        transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = (
+            LlamaRotaryEmbedding
+        )
+        return
 
     @staticmethod
-    def from_pretrained(  #TODO: Change after release
-        model_name        = "Qwen/Qwen3-7B",
-        max_seq_length    = 4096,
-        dtype             = None,
-        load_in_4bit      = True,
-        token             = None,
-        device_map        = "sequential",
-        rope_scaling      = None,
-        fix_tokenizer     = True,
-        model_patcher     = None,
-        tokenizer_name    = None,
+    def from_pretrained(  # TODO: Change after release
+        model_name = "Qwen/Qwen3-7B",
+        max_seq_length = 4096,
+        dtype = None,
+        load_in_4bit = True,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,
+        fix_tokenizer = True,
+        model_patcher = None,
+        tokenizer_name = None,
         trust_remote_code = False,
         **kwargs,
     ):
         return FastLlamaModel.from_pretrained(
-            model_name        = model_name,
-            max_seq_length    = max_seq_length,
-            dtype             = dtype,
-            load_in_4bit      = load_in_4bit,
-            token             = token,
-            device_map        = device_map,
-            rope_scaling      = rope_scaling,
-            fix_tokenizer     = fix_tokenizer,
-            model_patcher     = FastQwen3Model,
-            tokenizer_name    = tokenizer_name,
+            model_name = model_name,
+            max_seq_length = max_seq_length,
+            dtype = dtype,
+            load_in_4bit = load_in_4bit,
+            token = token,
+            device_map = device_map,
+            rope_scaling = rope_scaling,
+            fix_tokenizer = fix_tokenizer,
+            model_patcher = FastQwen3Model,
+            tokenizer_name = tokenizer_name,
             trust_remote_code = trust_remote_code,
             **kwargs,
         )
-    pass
-pass
diff --git a/unsloth/models/qwen3_moe.py b/unsloth/models/qwen3_moe.py
index b62a742d9..e80a87a95 100644
--- a/unsloth/models/qwen3_moe.py
+++ b/unsloth/models/qwen3_moe.py
@@ -31,6 +31,7 @@
     Qwen3MoeModel,
     Qwen3MoeForCausalLM,
 )
+
 # For Pytorch 2.1.1
 # TODO: Transformers moved to `attention_interface`. So we might not need these anymore
 # try:
@@ -46,26 +47,30 @@
 
 
 torch_nn_functional_softmax = torch.nn.functional.softmax
+
+
 def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = None):
     # adapted from https://github.com/huggingface/transformers/pull/36878/files#diff-0855b77fc27ad9449158a1c74953f909b011c00de7125f7c8e68d0ff209c092aR356-R370
-    
+
     bsz, seq_len, hd = X.shape
     X = X.view(-1, hd)
 
-    router_logits = fast_linear_forward(self.gate_proj, X, out = temp_gate) #pretty much the only change from transformers implementation.
+    router_logits = fast_linear_forward(
+        self.gate_proj, X, out = temp_gate
+    )  # pretty much the only change from transformers implementation.
 
     routing_weights = torch_nn_functional_softmax(router_logits, dim = -1)
-    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
+    routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
     # we cast back to the input dtype
     routing_weights = routing_weights.to(X.dtype)
-    final_X = torch.zeros(
-        (bsz * seq_len, hd), dtype=X.dtype, device=X.device
-    )
+    final_X = torch.zeros((bsz * seq_len, hd), dtype = X.dtype, device = X.device)
 
     # One hot encode the selected experts to create an expert mask
     # this will be used to easily index which expert is going to be sollicitated
-    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+    expert_mask = torch.nn.functional.one_hot(
+        selected_experts, num_classes = self.num_experts
+    ).permute(2, 1, 0)
 
     # Loop over all available experts in the model and perform the computation on each expert
     for expert_idx in range(self.num_experts):
@@ -76,66 +81,76 @@ def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = Non
         # the current expert. We need to make sure to multiply the output hidden
         # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
         current_state = X[None, top_x].reshape(-1, hd)
-        current_X = expert_layer(current_state) * routing_weights[top_x, idx, None] # Qwen3MoeMLP.forward = fast_swiglu_inference takes care of making this faster. Analogous to Dense models' MLP
+        current_X = (
+            expert_layer(current_state) * routing_weights[top_x, idx, None]
+        )  # Qwen3MoeMLP.forward = fast_swiglu_inference takes care of making this faster. Analogous to Dense models' MLP
 
         # However `index_add_` only support torch tensors for indexing so we'll use
         # the `top_x` tensor here.
         final_X.index_add_(0, top_x, current_X.to(X.dtype))
     final_X = final_X.reshape(bsz, seq_len, hd)
     return final_X, router_logits
-pass
 
 
 def Qwen3MoeDecoderLayer_fast_forward(
     self,
-    hidden_states:        torch.Tensor,
-    causal_mask:          Optional[BlockDiagonalCausalMask] = None,
-    attention_mask:       Optional[torch.Tensor] = None,
-    position_ids:         Optional[torch.LongTensor] = None,
-    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
-    output_attentions:    Optional[bool] = False,
-    output_router_logits:    Optional[bool] = False,
-    use_cache:            Optional[bool] = False,
-    padding_mask:         Optional[torch.LongTensor] = None,
-    position_embeddings:  Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    *args, **kwargs,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    output_router_logits: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    *args,
+    **kwargs,
 ):
     residual = hidden_states
 
-    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
+    if use_cache and hasattr(
+        self, "_flag_for_generation"
+    ):  # past_key_value is not None:
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.input_layernorm, hidden_states
+        )
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
-            _flag_for_generation=self._flag_for_generation,
+            _flag_for_generation = self._flag_for_generation,
         )
         hidden_states = residual + hidden_states
 
         # MoE Router MLP
         residual = hidden_states
-        hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
-        hidden_states, router_logits = Qwen3MoeSparseMoeBlock_fast_forward(self.mlp, hidden_states)
+        hidden_states = fast_rms_layernorm_inference(
+            self.post_attention_layernorm, hidden_states
+        )
+        hidden_states, router_logits = Qwen3MoeSparseMoeBlock_fast_forward(
+            self.mlp, hidden_states
+        )
         hidden_states = residual + hidden_states
     else:
         residual = hidden_states
         hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            causal_mask=causal_mask,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            padding_mask=padding_mask,
+            hidden_states = hidden_states,
+            causal_mask = causal_mask,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_value = past_key_value,
+            output_attentions = output_attentions,
+            use_cache = use_cache,
+            padding_mask = padding_mask,
             position_embeddings = position_embeddings,
         )
         hidden_states = residual + hidden_states
@@ -145,38 +160,41 @@ def Qwen3MoeDecoderLayer_fast_forward(
         hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
         hidden_states, router_logits = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-    pass
 
     outputs = (hidden_states,)
-    if output_attentions: outputs += (self_attn_weights,)
-    if output_router_logits: outputs += (router_logits,)
-    if use_cache: outputs += (present_key_value,)
+    if output_attentions:
+        outputs += (self_attn_weights,)
+    if output_router_logits:
+        outputs += (router_logits,)
+    if use_cache:
+        outputs += (present_key_value,)
     return outputs
 
 
-
 class FastQwen3MoeModel(FastQwen3Model):
-
     @staticmethod
     def pre_patch():
         init_name, function = patch_linear_scaling(
-            model_name         = "Qwen3Moe",
-            rope_module        = LlamaRotaryEmbedding,
+            model_name = "Qwen3Moe",
+            rope_module = LlamaRotaryEmbedding,
             scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
-            attention_module   = Qwen3MoeAttention,
+            attention_module = Qwen3MoeAttention,
         )
         if init_name is not None:
             exec(function, globals())
-            Qwen3MoeAttention.__init__  = eval(init_name)
-        pass
-        Qwen3MoeAttention      .forward = Qwen3Attention_fast_forward
+            Qwen3MoeAttention.__init__ = eval(init_name)
+        Qwen3MoeAttention.forward = Qwen3Attention_fast_forward
         # Qwen3SdpaAttention   .forward = Qwen3Attention_fast_forward
         # Qwen3FlashAttention2 .forward = Qwen3Attention_fast_forward
-        Qwen3MoeSparseMoeBlock .forward = Qwen3MoeSparseMoeBlock_fast_forward
-        Qwen3MoeMLP            .forward = fast_swiglu_inference # This is analogous to Dense models' MLP
-        Qwen3MoeDecoderLayer   .forward = Qwen3MoeDecoderLayer_fast_forward
-        Qwen3MoeModel          .forward = LlamaModel_fast_forward
-        Qwen3MoeForCausalLM    .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
+        Qwen3MoeSparseMoeBlock.forward = Qwen3MoeSparseMoeBlock_fast_forward
+        Qwen3MoeMLP.forward = (
+            fast_swiglu_inference  # This is analogous to Dense models' MLP
+        )
+        Qwen3MoeDecoderLayer.forward = Qwen3MoeDecoderLayer_fast_forward
+        Qwen3MoeModel.forward = LlamaModel_fast_forward
+        Qwen3MoeForCausalLM.forward = CausalLM_fast_forward(
+            LlamaModel_fast_forward_inference
+        )
         PeftModelForCausalLM.forward = PeftModel_fast_forward
         fix_prepare_inputs_for_generation(Qwen3MoeForCausalLM)
 
@@ -186,39 +204,38 @@ def pre_patch():
         # https://github.com/huggingface/transformers/pull/27931
         # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\
         import transformers.models.qwen3_moe.modeling_qwen3_moe
-        transformers.models.Qwen3Moe.modeling_qwen3_moe.Qwen3MoeRotaryEmbedding = LlamaRotaryEmbedding
-        return
-    pass
 
+        transformers.models.Qwen3Moe.modeling_qwen3_moe.Qwen3MoeRotaryEmbedding = (
+            LlamaRotaryEmbedding
+        )
+        return
 
     @staticmethod
-    def from_pretrained(  #TODO: Change after release
-        model_name        = "Qwen/Qwen3-7B",
-        max_seq_length    = 4096,
-        dtype             = None,
-        load_in_4bit      = True,
-        token             = None,
-        device_map        = "sequential",
-        rope_scaling      = None,
-        fix_tokenizer     = True,
-        model_patcher     = None,
-        tokenizer_name    = None,
+    def from_pretrained(  # TODO: Change after release
+        model_name = "Qwen/Qwen3-7B",
+        max_seq_length = 4096,
+        dtype = None,
+        load_in_4bit = True,
+        token = None,
+        device_map = "sequential",
+        rope_scaling = None,
+        fix_tokenizer = True,
+        model_patcher = None,
+        tokenizer_name = None,
         trust_remote_code = False,
         **kwargs,
     ):
         return FastLlamaModel.from_pretrained(
-            model_name        = model_name,
-            max_seq_length    = max_seq_length,
-            dtype             = dtype,
-            load_in_4bit      = load_in_4bit,
-            token             = token,
-            device_map        = device_map,
-            rope_scaling      = rope_scaling,
-            fix_tokenizer     = fix_tokenizer,
-            model_patcher     = FastQwen3Model,
-            tokenizer_name    = tokenizer_name,
+            model_name = model_name,
+            max_seq_length = max_seq_length,
+            dtype = dtype,
+            load_in_4bit = load_in_4bit,
+            token = token,
+            device_map = device_map,
+            rope_scaling = rope_scaling,
+            fix_tokenizer = fix_tokenizer,
+            model_patcher = FastQwen3Model,
+            tokenizer_name = tokenizer_name,
             trust_remote_code = trust_remote_code,
             **kwargs,
         )
-    pass
-pass
\ No newline at end of file
diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 9ced8f199..dc91ea13f 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -35,26 +35,28 @@
 )
 
 torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : False, # Disable Triton mm kernels
-    "shape_padding"     : True,
-    "trace.enabled"     : False,
-    "triton.cudagraphs" : False,
+    "epilogue_fusion": True,
+    "max_autotune": False,  # Disable Triton mm kernels
+    "shape_padding": True,
+    "trace.enabled": False,
+    "triton.cudagraphs": False,
 }
 
 from trl import __version__ as trl_version
 from unsloth_zoo.utils import Version
+
 trl_version = Version(trl_version)
 
+
 def vLLMSamplingParams(**kwargs):
     from vllm import SamplingParams
+
     sampling_params = SamplingParams(**kwargs)
     sampling_params._set_kwargs = kwargs
     return sampling_params
-pass
 
-def PatchRL(FastLanguageModel):
 
+def PatchRL(FastLanguageModel):
     from trl.models.utils import unwrap_model_for_generation
     from contextlib import contextmanager
 
@@ -67,12 +69,13 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs):
             # We must use .clone for Unsloth since we force inference_mode
             # Rather we should have used no_grad
             original_generate = unwrapped_model.generate
+
             def generate_with_clone(*args, **kwargs):
                 out = original_generate(*args, **kwargs)
                 if isinstance(out, torch.Tensor):
                     return out.clone()
                 return out
-            pass
+
             unwrapped_model.generate = generate_with_clone
 
             try:
@@ -81,14 +84,18 @@ def generate_with_clone(*args, **kwargs):
                 # Restore generate and return
                 unwrapped_model.generate = original_generate
                 FastLanguageModel.for_training(model)
-            pass
-        pass
-    pass
 
     from transformers import Trainer
     from transformers.trainer_pt_utils import nested_detach
-    @torch.no_grad()    
-    def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_keys,):
+
+    @torch.no_grad()
+    def unsloth_prediction_step(
+        self,
+        model,
+        inputs,
+        prediction_loss_only,
+        ignore_keys,
+    ):
         """
         Perform an evaluation step on `model` using `inputs`.
         Subclass and override to inject custom behavior.
@@ -108,19 +115,27 @@ def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_key
             Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
             logits and labels (each being optional).
         """
-        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        has_labels = (
+            False
+            if len(self.label_names) == 0
+            else all(inputs.get(k) is not None for k in self.label_names)
+        )
         # For CLIP-like models capable of returning loss values.
         # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
         # is `True` in `model.forward`.
         return_loss = inputs.get("return_loss", None)
         if return_loss is None:
             return_loss = self.can_return_loss
-        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+        loss_without_labels = (
+            True if len(self.label_names) == 0 and return_loss else False
+        )
 
         inputs = self._prepare_inputs(inputs)
         if ignore_keys is None:
             if hasattr(self.model, "config"):
-                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+                ignore_keys = getattr(
+                    self.model.config, "keys_to_ignore_at_inference", []
+                )
             else:
                 ignore_keys = []
 
@@ -131,25 +146,36 @@ def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_key
                 labels = labels[0]
         else:
             labels = None
-            
+
         os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
         with torch.no_grad():
             if has_labels or loss_without_labels:
                 with self.compute_loss_context_manager():
-                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    loss, outputs = self.compute_loss(
+                        model, inputs, return_outputs = True
+                    )
                 loss = loss.mean().detach()
 
                 if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                    logits = tuple(
+                        v for k, v in outputs.items() if k not in ignore_keys + ["loss"]
+                    )
                 else:
                     logits = outputs[1:]
             else:
                 loss = None
                 with self.compute_loss_context_manager():
-                    tokenized_output = self.processing_class(inputs["prompt"], padding=True, truncation=True, return_tensors="pt").to(model.device)
+                    tokenized_output = self.processing_class(
+                        inputs["prompt"],
+                        padding = True,
+                        truncation = True,
+                        return_tensors = "pt",
+                    ).to(model.device)
                     outputs = model(**tokenized_output)
                 if isinstance(outputs, dict):
-                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                    logits = tuple(
+                        v for k, v in outputs.items() if k not in ignore_keys
+                    )
                 else:
                     logits = outputs
                 # TODO: this needs to be fixed and made cleaner later.
@@ -164,26 +190,30 @@ def unsloth_prediction_step(self, model, inputs, prediction_loss_only,ignore_key
             logits = logits[0]
 
         return (loss, logits, labels)
+
     import trl.trainer
+
     trainers = dir(trl.trainer)
     trainers = [x for x in trainers if x.endswith("_trainer")]
     unwrap = "unwrap_model_for_generation"
     for trainer in trainers:
-        try: current_trainer = eval(f"trl.trainer.{trainer}")
-        except: continue
+        try:
+            current_trainer = eval(f"trl.trainer.{trainer}")
+        except:
+            continue
         if hasattr(current_trainer, unwrap):
-            try: exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}")
-            except: continue
+            try:
+                exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}")
+            except:
+                continue
     exec(f"Trainer.prediction_step=unsloth_prediction_step")
-    pass
-pass
 
 
-selective_log_softmax            = RL_REPLACEMENTS["selective_log_softmax"]
-calculate_pad_tokens_in_prompt   = RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"]
+selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"]
+calculate_pad_tokens_in_prompt = RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"]
 create_completion_attention_mask = RL_REPLACEMENTS["create_completion_attention_mask"]
-left_pack_padding                = RL_REPLACEMENTS["left_pack_padding"]
-align_logprobs_with_mask         = RL_REPLACEMENTS["align_logprobs_with_mask"]
+left_pack_padding = RL_REPLACEMENTS["left_pack_padding"]
+align_logprobs_with_mask = RL_REPLACEMENTS["align_logprobs_with_mask"]
 
 RLTrainer_replacement = '''
 import os
@@ -282,36 +312,58 @@ def __init__({RLTrainer_arguments},
 pass
 '''
 
+
 def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     # Patch for vLLM and Unsloth PEFT
     import trl
     import trl.trainer
+
     try:
         trainer = eval(f"trl.trainer.{trainer_file}")
     except Exception as error:
         return
 
     # Get SFTTrainer and SFTConfig names
-    name   = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()]
-    config = [x for x in dir(trainer) if x.endswith("Config")  and x != "Config"  and trainer_file.split("_")[0] in x.lower()]
-    if len(name)   != 1: return
-    if len(config) != 1: return
+    name = [
+        x
+        for x in dir(trainer)
+        if x.endswith("Trainer")
+        and x != "Trainer"
+        and trainer_file.split("_")[0] in x.lower()
+    ]
+    config = [
+        x
+        for x in dir(trainer)
+        if x.endswith("Config")
+        and x != "Config"
+        and trainer_file.split("_")[0] in x.lower()
+    ]
+    if len(name) != 1:
+        return
+    if len(config) != 1:
+        return
 
     # Get SFTTrainer, SFTConfig
     RLTrainer_name = name[0]
-    RLConfig_name  = config[0]
-    try: RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
-    except: return
-    try: RLConfig  = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" )
-    except: return
+    RLConfig_name = config[0]
+    try:
+        RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
+    except:
+        return
+    try:
+        RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}")
+    except:
+        return
 
     # Check name
-    if RLTrainer.__name__.startswith("Unsloth"): return
-    if RLConfig .__name__.startswith("Unsloth"): return
+    if RLTrainer.__name__.startswith("Unsloth"):
+        return
+    if RLConfig.__name__.startswith("Unsloth"):
+        return
 
     # Get old source
     old_RLTrainer_source = inspect.getsource(RLTrainer)
-    old_RLConfig_source  = inspect.getsource(RLConfig)
+    old_RLConfig_source = inspect.getsource(RLConfig)
 
     all_imports = dir(trainer)
     # Fix _deprecate_arguments not getting imported so stop __ but not _
@@ -322,23 +374,38 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     processed = []
     for RLobject in [RLTrainer, RLConfig]:
         parameters = inspect.signature(RLobject.__init__).parameters
-        types = (bool, type(None), int, float, str,)
+        types = (
+            bool,
+            type(None),
+            int,
+            float,
+            str,
+        )
         arguments = ["self"]
         call_args = []
         for k, v in parameters.items():
-            if k == "self": continue
+            if k == "self":
+                continue
             v = v.default
-            if v == "\n": v = re.escape("\n")
-            if v is EMPTY: arguments.append(k)
-            elif type(v) is str:   arguments.append(f"{k} = '{v}'")
-            elif type(v) in types: arguments.append(f"{k} = {v}")
-            else: continue
+            if v == "\n":
+                v = re.escape("\n")
+            if v is EMPTY:
+                arguments.append(k)
+            elif type(v) is str:
+                arguments.append(f"{k} = '{v}'")
+            elif type(v) in types:
+                arguments.append(f"{k} = {v}")
+            else:
+                continue
             call_args.append(f"{k} = {k}")
-        pass
-        arguments = f"\n{' '*8}" + f",\n{' '*8}".join(arguments)
-        call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args)
-        processed.append((arguments, call_args,))
-    pass
+        arguments = f"\n{' ' * 8}" + f",\n{' ' * 8}".join(arguments)
+        call_args = f"\n{' ' * 12}" + f",\n{' ' * 12}".join(call_args)
+        processed.append(
+            (
+                arguments,
+                call_args,
+            )
+        )
 
     # Process RLTrainer first
     arguments, call_args = processed[0]
@@ -346,274 +413,269 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
 
     # Add tokenizer if not seen
     if "tokenizer" not in parameters and "processing_class" in parameters:
-        arguments += f",\n{' '*8}tokenizer = None"
+        arguments += f",\n{' ' * 8}tokenizer = None"
         call_args = call_args.replace(
             "processing_class = processing_class",
             "processing_class = tokenizer if tokenizer is not None else processing_class",
         )
-    pass
 
     # Edit bf16, fp16 by checking model's dtype/torch_dtype directly
     extra_args = ""
     if "args" in call_args and "model" in call_args:
-        mixed_precision = \
-        "use_bf16 = getattr(args, 'bf16', False)\n"\
-        "if type(use_bf16) is not bool: use_bf16 = False\n"\
-        "use_fp16 = getattr(args, 'fp16', False)\n"\
-        "if type(use_fp16) is not bool: use_fp16 = False\n"\
-        "force_float32 = False\n"\
-        "full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'\n"\
-        "if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):\n"\
-        "    print('Unsloth: Switching to float32 training since model cannot work with float16')\n"\
-        "    force_float32 = True\n"\
-        "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"\
-        "dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)\n"\
-        "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\
-        "from unsloth_zoo.utils import _get_dtype\n"\
-        "dtype = _get_dtype(dtype)\n"\
-        "float16 = dtype == torch.float16\n"\
-        "if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\
-        "if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\
-        "if force_float32:\n"\
-        "    # Forced float32 training\n"\
-        "    args.fp16 = False\n"\
-        "    args.bf16 = False\n"\
-        "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"\
-        "elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\
-        "    # Mixed precision training\n"\
-        "    args.fp16 = float16\n"\
-        "    args.bf16 = not float16\n"\
-        "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n"
-        "elif mixed_precision_dtype == 'bfloat16':\n"\
-        "    # Both False since bfloat16 full finetuning doesn't do any autocasting.\n"\
-        "    args.fp16 = False\n"\
-        "    args.bf16 = False\n"\
-        "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"
+        mixed_precision = (
+            "use_bf16 = getattr(args, 'bf16', False)\n"
+            "if type(use_bf16) is not bool: use_bf16 = False\n"
+            "use_fp16 = getattr(args, 'fp16', False)\n"
+            "if type(use_fp16) is not bool: use_fp16 = False\n"
+            "force_float32 = False\n"
+            "full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'\n"
+            "if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):\n"
+            "    print('Unsloth: Switching to float32 training since model cannot work with float16')\n"
+            "    force_float32 = True\n"
+            "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"
+            "dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)\n"
+            "if dtype is None: dtype = model.get_input_embeddings().dtype\n"
+            "from unsloth_zoo.utils import _get_dtype\n"
+            "dtype = _get_dtype(dtype)\n"
+            "float16 = dtype == torch.float16\n"
+            "if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"
+            "if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"
+            "if force_float32:\n"
+            "    # Forced float32 training\n"
+            "    args.fp16 = False\n"
+            "    args.bf16 = False\n"
+            "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"
+            "elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"
+            "    # Mixed precision training\n"
+            "    args.fp16 = float16\n"
+            "    args.bf16 = not float16\n"
+            "    os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n"
+        )
+        "elif mixed_precision_dtype == 'bfloat16':\n    # Both False since bfloat16 full finetuning doesn't do any autocasting.\n    args.fp16 = False\n    args.bf16 = False\n    os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"
         extra_args += mixed_precision
-    pass
 
     # Check if per_device_eval_batch_size (default 8) bigger than bsz
     # Also use FP16 / BF16 evaluation
     if "args" in call_args:
         # Check eval_dataset first
         if "eval_dataset" in call_args:
-            check_eval_dataset = \
-            "if getattr(args, 'eval_dataset', None) is not None and "\
-            "getattr(args, 'eval_strategy', 'no') == 'no':\n"\
-            "    args.eval_strategy = 'steps'\n"\
-            "    if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n"
+            check_eval_dataset = (
+                "if getattr(args, 'eval_dataset', None) is not None and "
+                "getattr(args, 'eval_strategy', 'no') == 'no':\n"
+                "    args.eval_strategy = 'steps'\n"
+                "    if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n"
+            )
             extra_args += check_eval_dataset
-        pass
 
         # Check if gradient accumulation bug fix is applied
-        check_ga = \
-        "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"\
-        "if ga_steps is not None and ga_steps > 1:\n"\
-        "    from transformers import __version__ as transformers_version\n"\
-        "    if Version(transformers_version) <= Version('4.45.2'):\n"\
-        "        print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"\
-        "              '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n"
+        check_ga = (
+            "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"
+            "if ga_steps is not None and ga_steps > 1:\n"
+            "    from transformers import __version__ as transformers_version\n"
+            "    if Version(transformers_version) <= Version('4.45.2'):\n"
+            "        print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"
+            "              '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n"
+        )
         extra_args += check_ga
 
-        eval_changes = \
-        "if getattr(args, 'eval_strategy', 'no') != 'no':\n"\
-        "    eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"\
-        "    if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"\
-        "    if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n"\
-        "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n"\
-        "if type(fp16_full_eval) is not bool: fp16_full_eval = False\n"\
-        "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\
-        "if type(bf16_full_eval) is not bool: bf16_full_eval = False\n"\
-        "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\
-        "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\
-        "if force_float32:\n"\
-        "    args.bf16_full_eval = False\n"\
-        "    args.fp16_full_eval = False\n"\
-        "elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\
-        "    args.bf16_full_eval = True\n"\
-        "    args.fp16_full_eval = False\n"\
-        "elif not bf16_full_eval and not fp16_full_eval:\n"\
-        "    args.bf16_full_eval = args.bf16\n"\
-        "    args.fp16_full_eval = args.fp16\n"
+        eval_changes = (
+            "if getattr(args, 'eval_strategy', 'no') != 'no':\n"
+            "    eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"
+            "    if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"
+            "    if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n"
+            "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n"
+            "if type(fp16_full_eval) is not bool: fp16_full_eval = False\n"
+            "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"
+            "if type(bf16_full_eval) is not bool: bf16_full_eval = False\n"
+            "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"
+            "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"
+            "if force_float32:\n"
+            "    args.bf16_full_eval = False\n"
+            "    args.fp16_full_eval = False\n"
+            "elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"
+            "    args.bf16_full_eval = True\n"
+            "    args.fp16_full_eval = False\n"
+            "elif not bf16_full_eval and not fp16_full_eval:\n"
+            "    args.bf16_full_eval = args.bf16\n"
+            "    args.fp16_full_eval = args.fp16\n"
+        )
         extra_args += eval_changes
-    pass
 
     # Force logits to be produced if preprocess_logits_for_metrics or compute_metrics is used
     if "model" in call_args:
-        logits_check = \
-        "_output_logits = False\n"\
-        "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"\
-        "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"\
-        "if _output_logits:\n"\
-        "    os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
+        logits_check = (
+            "_output_logits = False\n"
+            "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"
+            "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"
+            "if _output_logits:\n"
+            "    os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
+        )
         extra_args += logits_check
-    pass
 
     # Check max_seq_length
     if "model" in call_args:
-        length_check = \
-        "if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):\n"\
-        "    pass\n"\
-        "else:\n"\
-        "    model_max_seq_length = getattr(model, 'max_seq_length', None)\n"\
-        "    args_max_seq_length  = getattr(args,  'max_seq_length', None)\n"\
-        "    if args_max_seq_length is None and model_max_seq_length is not None:\n"\
-        "        max_seq_length = model.max_seq_length\n"\
-        "        if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length\n"
-        "    elif args_max_seq_length is not None and model_max_seq_length is not None:\n"\
-        "        if args_max_seq_length > model_max_seq_length:\n"\
-        "            print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but \n"\
-        "                   the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n"\
-        "            args.max_seq_length = model_max_seq_length\n"
+        length_check = (
+            "if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):\n"
+            "    pass\n"
+            "else:\n"
+            "    model_max_seq_length = getattr(model, 'max_seq_length', None)\n"
+            "    args_max_seq_length  = getattr(args,  'max_seq_length', None)\n"
+            "    if args_max_seq_length is None and model_max_seq_length is not None:\n"
+            "        max_seq_length = model.max_seq_length\n"
+            "        if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length\n"
+        )
+        "    elif args_max_seq_length is not None and model_max_seq_length is not None:\n        if args_max_seq_length > model_max_seq_length:\n            print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but \n                   the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n            args.max_seq_length = model_max_seq_length\n"
         extra_args += length_check
 
         # At this point max_seq_length might be set, but trl is moving to max_length
         if trainer_file == "sft_trainer":
-            max_length_check = \
-            "if 'max_length' not in locals() and not hasattr(args, 'max_length'):\n"\
-            "    pass\n"\
-            "else:\n"\
-            "    if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:\n"\
-            "        if hasattr(args, 'max_length'):\n"\
-            "            args.max_length = args.max_seq_length\n"\
-            "            max_length = args.max_length\n"\
-            "    else:\n"\
-            "        model_max_length = getattr(model, 'max_seq_length', None)\n"\
-            "        if model_max_length is None: model_max_length = getattr(model, 'max_length', None)\n"\
-            "        if model_max_length is not None:\n"\
-            "            args.max_length = model_max_length\n"\
-            "            max_length = args.max_length\n"\
-            "        elif hasattr(args, 'max_length') and args.max_length is not None:\n"\
-            "            max_length = args.max_length\n"\
-            "            # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set\n"\
-            "            setattr(model, 'max_seq_length', max_length)\n"\
-            "        else:\n"\
-            "            print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')\n"\
-            "            args.max_length = 1024\n"
+            max_length_check = (
+                "if 'max_length' not in locals() and not hasattr(args, 'max_length'):\n"
+                "    pass\n"
+                "else:\n"
+                "    if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:\n"
+                "        if hasattr(args, 'max_length'):\n"
+                "            args.max_length = args.max_seq_length\n"
+                "            max_length = args.max_length\n"
+                "    else:\n"
+                "        model_max_length = getattr(model, 'max_seq_length', None)\n"
+                "        if model_max_length is None: model_max_length = getattr(model, 'max_length', None)\n"
+                "        if model_max_length is not None:\n"
+                "            args.max_length = model_max_length\n"
+                "            max_length = args.max_length\n"
+                "        elif hasattr(args, 'max_length') and args.max_length is not None:\n"
+                "            max_length = args.max_length\n"
+                "            # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set\n"
+                "            setattr(model, 'max_seq_length', max_length)\n"
+                "        else:\n"
+                "            print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')\n"
+                "            args.max_length = 1024\n"
+            )
             extra_args += max_length_check
-    pass
 
     # Enable for training and move padding side of tokenizer to right
     if "model" in call_args:
-        training_check = \
-        "if model is not None and hasattr(model, 'for_training'):\n"\
-        "    model.for_training()\n"\
-        "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\
-        "if 'processing_class' in locals():\n"\
-        "    if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\
-        "    if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): "\
-        "processing_class.tokenizer.padding_side = 'right'\n"
+        training_check = (
+            "if model is not None and hasattr(model, 'for_training'):\n"
+            "    model.for_training()\n"
+            "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"
+            "if 'processing_class' in locals():\n"
+            "    if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"
+            "    if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): "
+            "processing_class.tokenizer.padding_side = 'right'\n"
+        )
         extra_args += training_check
-    pass
 
     # Check data collator if it's correct!
     if "data_collator" in call_args and "train_dataset" in call_args:
-        data_collator_check = \
-        "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"\
-        "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"\
-        "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
-        "    if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\
-        "        data_collator = TransformersDataCollatorForLanguageModeling(\n"\
-        "            __tokenizer,\n"\
-        "            mlm = False,\n"\
-        "            mlm_probability = 0.0,\n"\
-        "            pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
-        "        )\n"\
-        "    elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\
-        "        data_collator = DataCollatorForSeq2Seq(\n"\
-        "            __tokenizer,\n"\
-        "            pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
-        "        )\n"\
-        "else:\n"\
-        "    if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"\
-        "    if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"\
-        "    if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\n"
+        data_collator_check = (
+            "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"
+            "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"
+            "if not isinstance(data_collator, UnslothVisionDataCollator):\n"
+            "    if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"
+            "        data_collator = TransformersDataCollatorForLanguageModeling(\n"
+            "            __tokenizer,\n"
+            "            mlm = False,\n"
+            "            mlm_probability = 0.0,\n"
+            "            pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
+            "        )\n"
+            "    elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"
+            "        data_collator = DataCollatorForSeq2Seq(\n"
+            "            __tokenizer,\n"
+            "            pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
+            "        )\n"
+            "else:\n"
+            "    if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"
+            "    if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"
+            "    if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\n"
+        )
         extra_args += data_collator_check
 
         # Also check if .pad exists -> if not, and is VLM, then change it!
-        pad_check = \
-        "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
-        "    if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"\
-        "        if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\
-        "            data_collator = DataCollatorForSeq2Seq(\n"\
-        "                __tokenizer.tokenizer,\n"\
-        "                pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
-        "            )\n"\
-        "        else:\n"\
-        "            data_collator = TransformersDataCollatorForLanguageModeling(\n"\
-        "                __tokenizer.tokenizer,\n"\
-        "                mlm = False,\n"\
-        "                mlm_probability = 0.0,\n"\
-        "                pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
-        "            )\n"
+        pad_check = (
+            "if not isinstance(data_collator, UnslothVisionDataCollator):\n"
+            "    if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"
+            "        if isinstance(data_collator, DataCollatorForSeq2Seq):\n"
+            "            data_collator = DataCollatorForSeq2Seq(\n"
+            "                __tokenizer.tokenizer,\n"
+            "                pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
+            "            )\n"
+            "        else:\n"
+            "            data_collator = TransformersDataCollatorForLanguageModeling(\n"
+            "                __tokenizer.tokenizer,\n"
+            "                mlm = False,\n"
+            "                mlm_probability = 0.0,\n"
+            "                pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
+            "            )\n"
+        )
         extra_args += pad_check
-    pass
 
     # Check NEFTune
     if "model" in call_args:
-        neftune_check = \
-        "if hasattr(self, 'neftune_hook_handle'):\n"\
-        "    self.neftune_hook_handle.remove()\n"\
-        "    if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\
-        "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"\
-        "    model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\
-        "pass\n"
+        neftune_check = (
+            "if hasattr(self, 'neftune_hook_handle'):\n"
+            "    self.neftune_hook_handle.remove()\n"
+            "    if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"
+            "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"
+            "    model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"
+            "pass\n"
+        )
         RLTrainer_post += neftune_check
-    pass
 
     # Add accelerator scaler to model
     if "model" in call_args:
-        accelerator_check = \
-        "if hasattr(self, 'accelerator'):\n"\
-        "    scaler = self.accelerator.scaler\n"\
-        "    current_model = model\n"\
-        "    while hasattr(current_model, 'model'):\n"\
-        "        current_model.accelerator_scaler = scaler\n"\
-        "        current_model = current_model.model\n"\
-        "    current_model.accelerator_scaler = scaler\n"\
-        "pass\n"
+        accelerator_check = (
+            "if hasattr(self, 'accelerator'):\n"
+            "    scaler = self.accelerator.scaler\n"
+            "    current_model = model\n"
+            "    while hasattr(current_model, 'model'):\n"
+            "        current_model.accelerator_scaler = scaler\n"
+            "        current_model = current_model.model\n"
+            "    current_model.accelerator_scaler = scaler\n"
+            "pass\n"
+        )
         RLTrainer_post += accelerator_check
-    pass
 
     # Add enabling and disabling training modes
     if "model" in call_args:
-        training_check = \
-        "if hasattr(self, 'train'):\n"\
-        "    self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)\n"\
-        "pass\n"
+        training_check = (
+            "if hasattr(self, 'train'):\n"
+            "    self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)\n"
+            "pass\n"
+        )
         RLTrainer_post += training_check
-    pass
 
     # Edit optional metrics
     other_metrics_processor = ""
     if trainer_file in RL_METRICS_CHANGES:
         process_extra_args = RL_METRICS_CHANGES[trainer_file]
         for process_extra_arg in process_extra_args:
-            other_metrics_processor += process_extra_arg(old_RLTrainer_source, old_RLConfig_source)
-    pass
+            other_metrics_processor += process_extra_arg(
+                old_RLTrainer_source, old_RLConfig_source
+            )
 
     # Add statistics as well!
-    extra_args += \
-        "other_metrics = []\n"\
-        f"{other_metrics_processor}\n"\
-        "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\
+    extra_args += (
+        "other_metrics = []\n"
+        f"{other_metrics_processor}\n"
+        "from unsloth_zoo.logging_utils import PatchRLStatistics\n"
         f"PatchRLStatistics('{trainer_file}', other_metrics)\n"
+    )
 
     # Patch optional args
     if trainer_file in RL_EXTRA_ARGS:
         process_extra_args = RL_EXTRA_ARGS[trainer_file]
         for process_extra_arg in process_extra_args:
             extra_args += process_extra_arg(call_args, extra_args)
-    pass
 
     # Create RLTrainer args
     extra_args = extra_args.split("\n")
-    extra_args = "\n".join(" "*8 + x for x in extra_args)
+    extra_args = "\n".join(" " * 8 + x for x in extra_args)
     RLTrainer_post = RLTrainer_post.split("\n")
-    RLTrainer_post = "\n".join(" "*8 + x for x in RLTrainer_post)
-    RLTrainer_arguments  = arguments
+    RLTrainer_post = "\n".join(" " * 8 + x for x in RLTrainer_post)
+    RLTrainer_arguments = arguments
     RLTrainer_extra_args = extra_args
-    RLTrainer_call_args  = call_args
+    RLTrainer_call_args = call_args
 
     # Fix RLConfig next
     arguments, call_args = processed[1]
@@ -621,32 +683,32 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
 
     # Edit GA / bsz and weight_decay
     replacements = {
-        "output_dir"                    : None,
-        "logging_nan_inf_filter"        : False,
-        "per_device_train_batch_size"   : 4,
-        "gradient_accumulation_steps"   : 2,
-        "weight_decay"                  : 0.01,
-        "warmup_ratio"                  : 0.1,
-        "seed"                          : 3407,
-        "optim"                         : "adamw_8bit",
-        "learning_rate"                 : 5e-05,
-        "per_device_eval_batch_size"    : 4,
-        "eval_accumulation_steps"       : 2,
-        "torch_empty_cache_steps"       : 250,
-        "logging_steps"                 : 1,
-        "max_seq_length"                : None,
-        "num_generations"               : 8,
+        "output_dir": None,
+        "logging_nan_inf_filter": False,
+        "per_device_train_batch_size": 4,
+        "gradient_accumulation_steps": 2,
+        "weight_decay": 0.01,
+        "warmup_ratio": 0.1,
+        "seed": 3407,
+        "optim": "adamw_8bit",
+        "learning_rate": 5e-05,
+        "per_device_eval_batch_size": 4,
+        "eval_accumulation_steps": 2,
+        "torch_empty_cache_steps": 250,
+        "logging_steps": 1,
+        "max_seq_length": None,
+        "num_generations": 8,
         # "steps_per_generation"          : 1, # Otherwise defaults to ga_steps which is wrong
         # "generation_batch_size"         : None, # Useless. If steps_per_generation set, generation_batch_size clashes
-        "top_k"                         : None,
-        "vllm_mode"                     : "colocate",
-        "generation_kwargs"             : {},
-        "bf16"                          : False,
-        "fp16"                          : False,
-        "include_tokens_per_second"     : False,
-        "include_num_input_tokens_seen" : False,
-        "auto_find_batch_size"          : False, # Auto /2 batch size - too many people complained so removing
-        "dataloader_pin_memory"         : True,
+        "top_k": None,
+        "vllm_mode": "colocate",
+        "generation_kwargs": {},
+        "bf16": False,
+        "fp16": False,
+        "include_tokens_per_second": False,
+        "include_num_input_tokens_seen": False,
+        "auto_find_batch_size": False,  # Auto /2 batch size - too many people complained so removing
+        "dataloader_pin_memory": True,
         # Might fail so disable for now
         # "dataloader_persistent_workers" : True, # Keeps dataloader in RAM
         # "dataloader_prefetch_factor"    : 2,
@@ -657,42 +719,38 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
         y = f"'{v}'" if type(v) is str else f"{v}"
         y = f"{k} = {y},\n"
         arguments = re.sub(x, y, arguments)
-    pass
 
     # Fix GRPO beta default as 0.001 TRL used to be 0.04, now 0.00!
     # https://github.com/huggingface/trl/pull/3516
     # https://verl.readthedocs.io/en/latest/examples/config.html
     if trainer_file == "grpo_trainer":
         replacements = {
-            "loss_type" : "bnpo",           # Default GRPO paper
-            "beta" : 0.001,                 # Recommended as seen in verl
-            "auto_find_batch_size" : False, # Cannot work on GRPO
+            "loss_type": "bnpo",  # Default GRPO paper
+            "beta": 0.001,  # Recommended as seen in verl
+            "auto_find_batch_size": False,  # Cannot work on GRPO
             # [TODO] See https://fengyao.notion.site/off-policy-rl
             # https://github.com/huggingface/trl/pull/3867 (August 7th)
-            "vllm_importance_sampling_correction" : False,
+            "vllm_importance_sampling_correction": False,
         }
         for k, v in replacements.items():
             x = f"{k}( = [^,\n]{{1,}})?,\n"
             y = f"'{v}'" if type(v) is str else f"{v}"
             y = f"{k} = {y},\n"
             arguments = re.sub(x, y, arguments)
-        pass
-    pass
 
     # Warn on too large or too small learning rate
     if "learning_rate" in call_args:
-        learning_rate_check = \
-        "if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "\
-        "Consider increasing it, otherwise gradient updates will be close to 0!')\n"\
-        "if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! "\
-        "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n"
+        learning_rate_check = (
+            "if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "
+            "Consider increasing it, otherwise gradient updates will be close to 0!')\n"
+            "if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! "
+            "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n"
+        )
         extra_args += learning_rate_check
-    pass
 
     # Check if max_seq_length is NOT defined (max_length is now default)
     if "max_seq_length" not in call_args and "max_length" in call_args:
-        max_seq_length_pre = \
-            """max_seq_length : Optional[int] = field(
+        max_seq_length_pre = """max_seq_length : Optional[int] = field(
         default = None,
         metadata = {'help': 'Maximum sequence length to truncate to.'},
     )"""
@@ -702,165 +760,165 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
         max_seq_length_pre = ""
         max_seq_length_call = ""
         max_seq_length_post = ""
-    pass
 
     # Add output_dir saving
     if "output_dir" in call_args:
         # Default checks
-        saving_check = \
-        "if output_dir is None and save_strategy == 'steps' and save_steps == 500:\n"\
-        "    output_dir = 'unsloth_training_checkpoints'\n"\
-        "    save_strategy = 'no'\n"
+        saving_check = (
+            "if output_dir is None and save_strategy == 'steps' and save_steps == 500:\n"
+            "    output_dir = 'unsloth_training_checkpoints'\n"
+            "    save_strategy = 'no'\n"
+        )
         extra_args += saving_check
-    pass
 
     # Edit dataset_num_proc
     if "dataset_num_proc" in call_args:
-        num_proc_check = \
-        "if dataset_num_proc is None:\n"\
-        "    from multiprocessing import cpu_count\n"\
-        "    dataset_num_proc = min(max(cpu_count()+4, 2), 64)\n"
+        num_proc_check = (
+            "if dataset_num_proc is None:\n"
+            "    from multiprocessing import cpu_count\n"
+            "    dataset_num_proc = min(max(cpu_count()+4, 2), 64)\n"
+        )
         extra_args += num_proc_check
-    pass
 
     # Add padding if flex attention is added
     if "pad_to_multiple_of" in call_args:
-        pad_to_multiple_of = \
-        "if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':\n"\
-        "    from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION\n"\
-        "    if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:\n"\
-        "        from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE\n"\
-        "        pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE\n"\
-        "\n"
+        pad_to_multiple_of = (
+            "if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':\n"
+            "    from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION\n"
+            "    if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:\n"
+            "        from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE\n"
+            "        pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE\n"
+            "\n"
+        )
         extra_args += pad_to_multiple_of
-    pass
 
     # Check for loss_type = dr_grpo and scale_rewards for GRPO
     if "loss_type" in call_args and "scale_rewards" in call_args:
         # See https://github.com/huggingface/trl/issues/3130#issuecomment-2746947835
         # DAPO uses per token loss so BNPO loss used
-        check_dr_grpo = \
-        "if loss_type.lower() == 'dr_grpo':\n"\
-        "    loss_type = 'dr_grpo'\n"\
-        "elif loss_type.lower() == 'dapo':\n"\
-        "    loss_type = 'dapo'\n"\
-        "if loss_type.lower() == 'dr_grpo':\n"\
-        "    if scale_rewards == None:\n"\
-        "        scale_rewards = True\n"\
-        "    elif scale_rewards == True:\n"\
-        "        print('Unsloth: The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\n"\
-        "        scale_rewards = False\n"\
-        "elif loss_type.lower() == 'dapo':\n"\
-        "    if mask_truncated_completions != True:\n"\
-        "        print('Unsloth: The DAPO paper recommends `mask_truncated_completions = True`')\n"\
-        "    if epsilon_high != 0.28:\n"\
-        "        print('Unsloth: The DAPO paper recommends `epsilon_high = 0.28`')\n"\
-        "    if beta != 0.0:\n"\
-        "        print('Unsloth: The DAPO paper recommends setting `beta = 0.0` to remove the KL term')\n"\
-        "    mask_truncated_completions = True\n"\
-        "    epsilon_high = 0.28\n"\
-        "    beta = 0.0\n"\
-        "\n"
+        check_dr_grpo = (
+            "if loss_type.lower() == 'dr_grpo':\n"
+            "    loss_type = 'dr_grpo'\n"
+            "elif loss_type.lower() == 'dapo':\n"
+            "    loss_type = 'dapo'\n"
+            "if loss_type.lower() == 'dr_grpo':\n"
+            "    if scale_rewards == None:\n"
+            "        scale_rewards = True\n"
+            "    elif scale_rewards == True:\n"
+            "        print('Unsloth: The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\n"
+            "        scale_rewards = False\n"
+            "elif loss_type.lower() == 'dapo':\n"
+            "    if mask_truncated_completions != True:\n"
+            "        print('Unsloth: The DAPO paper recommends `mask_truncated_completions = True`')\n"
+            "    if epsilon_high != 0.28:\n"
+            "        print('Unsloth: The DAPO paper recommends `epsilon_high = 0.28`')\n"
+            "    if beta != 0.0:\n"
+            "        print('Unsloth: The DAPO paper recommends setting `beta = 0.0` to remove the KL term')\n"
+            "    mask_truncated_completions = True\n"
+            "    epsilon_high = 0.28\n"
+            "    beta = 0.0\n"
+            "\n"
+        )
         extra_args += check_dr_grpo
-    pass
 
     # Check GRPO num_generations mismatch
     if "per_device_train_batch_size" in call_args and "num_generations" in call_args:
-        check_num_generations = \
-        "if (per_device_train_batch_size // num_generations) * num_generations != per_device_train_batch_size:\n"\
-        "    print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\
-                   "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\
-        "    per_device_train_batch_size = num_generations\n"\
-        "\n"
+        check_num_generations = (
+            "if (per_device_train_batch_size // num_generations) * num_generations != per_device_train_batch_size:\n"
+            "    print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"
+            "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"
+            "    per_device_train_batch_size = num_generations\n"
+            "\n"
+        )
         extra_args += check_num_generations
-    pass
 
     # Check temperature must not be <= 0. Also stop if >= 10
     if "temperature" in call_args:
-        check_temperature = \
-        "if temperature <= 0:\n"\
-        "    raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"\
-        "elif temperature >= 10:\n"\
-        "    raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')\n"\
-        "\n"
+        check_temperature = (
+            "if temperature <= 0:\n"
+            "    raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"
+            "elif temperature >= 10:\n"
+            "    raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')\n"
+            "\n"
+        )
         extra_args += check_temperature
-    pass
 
     # Edit config with anything extra
     if trainer_file in RL_CONFIG_CHANGES:
         process_extra_args = RL_CONFIG_CHANGES[trainer_file]
         for process_extra_arg in process_extra_args:
             extra_args += process_extra_arg(old_RLTrainer_source, old_RLConfig_source)
-    pass
 
     # Edit report_to and default it to nothing if max_steps is like 60
 
     # Create RLConfig args
     extra_args = extra_args.split("\n")
-    extra_args = "\n".join(" "*8 + x for x in extra_args)
-    RLConfig_arguments  = arguments
+    extra_args = "\n".join(" " * 8 + x for x in extra_args)
+    RLConfig_arguments = arguments
     RLConfig_extra_args = extra_args
-    RLConfig_call_args  = call_args
+    RLConfig_call_args = call_args
 
     # Patch vLLM and other functions
-    RLTrainer_extras = patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports)
+    RLTrainer_extras = patch_functions(
+        RLTrainer, trainer_file, RLTrainer_name, all_imports, imports
+    )
     if RLTrainer_extras is None:
         RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}"
 
     # Create full module
     exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)")
     __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__
-    if __RLTrainer_doc__ is None: __RLTrainer_doc__ = ""
-    __RLConfig_doc__  = eval(f"trl.trainer.{RLConfig_name}") .__doc__
-    if __RLConfig_doc__ is None: __RLConfig_doc__ = ""
+    if __RLTrainer_doc__ is None:
+        __RLTrainer_doc__ = ""
+    __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}").__doc__
+    if __RLConfig_doc__ is None:
+        __RLConfig_doc__ = ""
 
     # Get all pre-modules
     if trainer_file in RL_PRE_ITEMS:
         RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file])
     else:
         RL_pre = ""
-    pass
 
     # Check if SamplingParams is in there
     if "SamplingParams" in old_RLTrainer_source:
         RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams)
-    pass
 
     # Selective log softmax and other functions
     selective_log_softmax_code = inspect.getsource(selective_log_softmax)
-    calculate_pad_tokens_in_prompt_code = inspect.getsource(calculate_pad_tokens_in_prompt)
-    create_completion_attention_mask_code = inspect.getsource(create_completion_attention_mask)
+    calculate_pad_tokens_in_prompt_code = inspect.getsource(
+        calculate_pad_tokens_in_prompt
+    )
+    create_completion_attention_mask_code = inspect.getsource(
+        create_completion_attention_mask
+    )
     left_pack_padding_code = inspect.getsource(left_pack_padding)
     align_logprobs_with_mask_code = inspect.getsource(align_logprobs_with_mask)
     # Get final source code
     RLTrainer_source = RLTrainer_replacement.format(
-        RLTrainer_name       = RLTrainer_name,
-        __RLTrainer_doc__    = __RLTrainer_doc__,
-        RLTrainer_arguments  = RLTrainer_arguments,
+        RLTrainer_name = RLTrainer_name,
+        __RLTrainer_doc__ = __RLTrainer_doc__,
+        RLTrainer_arguments = RLTrainer_arguments,
         RLTrainer_extra_args = RLTrainer_extra_args,
-        RLTrainer_call_args  = RLTrainer_call_args,
-        RLTrainer_kwargs     = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:],
-
-        RLConfig_name        = RLConfig_name,
-        __RLConfig_doc__     = __RLConfig_doc__,
-        RLConfig_arguments   = RLConfig_arguments,
-        RLConfig_extra_args  = RLConfig_extra_args,
-        RLConfig_call_args   = RLConfig_call_args,
-        RLConfig_kwargs      = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:],
-
-        RLTrainer_extras     = RLTrainer_extras,
-        RLTrainer_post       = RLTrainer_post,
-        RL_pre               = RL_pre,
-
-        max_seq_length_pre   = max_seq_length_pre,
-        max_seq_length_call  = max_seq_length_call,
-        max_seq_length_post  = max_seq_length_post,
-
-        selective_log_softmax_code            = selective_log_softmax_code,
-        calculate_pad_tokens_in_prompt_code   = calculate_pad_tokens_in_prompt_code,
+        RLTrainer_call_args = RLTrainer_call_args,
+        RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0 :],
+        RLConfig_name = RLConfig_name,
+        __RLConfig_doc__ = __RLConfig_doc__,
+        RLConfig_arguments = RLConfig_arguments,
+        RLConfig_extra_args = RLConfig_extra_args,
+        RLConfig_call_args = RLConfig_call_args,
+        RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args.endswith(",") else 0 :],
+        RLTrainer_extras = RLTrainer_extras,
+        RLTrainer_post = RLTrainer_post,
+        RL_pre = RL_pre,
+        max_seq_length_pre = max_seq_length_pre,
+        max_seq_length_call = max_seq_length_call,
+        max_seq_length_post = max_seq_length_post,
+        selective_log_softmax_code = selective_log_softmax_code,
+        calculate_pad_tokens_in_prompt_code = calculate_pad_tokens_in_prompt_code,
         create_completion_attention_mask_code = create_completion_attention_mask_code,
-        left_pack_padding_code                = left_pack_padding_code,
-        align_logprobs_with_mask_code         = align_logprobs_with_mask_code,
+        left_pack_padding_code = left_pack_padding_code,
+        align_logprobs_with_mask_code = align_logprobs_with_mask_code,
     )
 
     if RLTrainer_name == "SFTTrainer":
@@ -870,15 +928,15 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
 
         # Temporary patch _is_vlm to False
         # as of 0.22 it only exists in sfttrainer
-        oriignal_is_vlm_text = 'self._is_vlm = True'
-        new_is_vlm_text = 'self._is_vlm = False'
-        RLTrainer_source = RLTrainer_source.replace(oriignal_is_vlm_text, new_is_vlm_text)
-
+        oriignal_is_vlm_text = "self._is_vlm = True"
+        new_is_vlm_text = "self._is_vlm = False"
+        RLTrainer_source = RLTrainer_source.replace(
+            oriignal_is_vlm_text, new_is_vlm_text
+        )
 
     # Remove multiple doc strings
     if __RLConfig_doc__ != "" and RLTrainer_source.count(__RLTrainer_doc__) == 2:
         RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1)
-    pass
 
     # Remove multiple newlines
     RLTrainer_source = re.sub(r"[\n]{3,}", "\n", RLTrainer_source)
@@ -893,15 +951,38 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     )
 
     # Patch Trainer
-    exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
-    exec(f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
-    exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
+    exec(
+        f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}",
+        locals(),
+        globals(),
+    )
+    exec(
+        f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}",
+        locals(),
+        globals(),
+    )
+    exec(
+        f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}",
+        locals(),
+        globals(),
+    )
 
     # Patch Config
-    exec(f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
-    exec(f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
-    exec(f"trl.trainer.{trainer_file}.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
-pass
+    exec(
+        f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}",
+        locals(),
+        globals(),
+    )
+    exec(
+        f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}",
+        locals(),
+        globals(),
+    )
+    exec(
+        f"trl.trainer.{trainer_file}.{RLConfig_name} = created_module.Unsloth{RLConfig_name}",
+        locals(),
+        globals(),
+    )
 
 
 def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports):
@@ -917,7 +998,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
             bracketed_comment,
             bracketed_comment.replace("(", "[").replace(")", "]"),
         )
-    pass
 
     # Remove peft_config
     init = init.replace("elif peft_config is None:", "elif False:")
@@ -926,9 +1006,14 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
     init = init.replace("if peft_config is not None:", "if False:")
     init = init.replace("get_peft_model(model, peft_config)", "model")
     # New TRL 0.20.0
-    init = init.replace("if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):", "if False:")
+    init = init.replace(
+        "if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):",
+        "if False:",
+    )
     # New TRL 0.20.0
-    init = init.replace("model = self._prepare_peft_model(model, peft_config, args)\n", "pass\n")
+    init = init.replace(
+        "model = self._prepare_peft_model(model, peft_config, args)\n", "pass\n"
+    )
 
     # Set use_vllm if not set
     if "args.use_vllm" in init and "model" in init and "args" in init:
@@ -940,43 +1025,45 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
         )
         if len(replacer) != 0:
             replacer = replacer[0]
-            vllm_setter = "\n" + " "*8 + \
-            "if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):\n" + \
-            " " * 12 + "if (getattr(args, 'use_vllm', False) == False):\n" + \
-            " " * 16 + "args.use_vllm = True\n"
-            #" " * 16 + "args.vllm_importance_sampling_correction = True\n" + \
-            #" " * 16 + "args.vllm_importance_sampling_cap = 2.0\n"
+            vllm_setter = (
+                "\n"
+                + " " * 8
+                + "if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):\n"
+                + " " * 12
+                + "if (getattr(args, 'use_vllm', False) == False):\n"
+                + " " * 16
+                + "args.use_vllm = True\n"
+            )
+            # " " * 16 + "args.vllm_importance_sampling_correction = True\n" + \
+            # " " * 16 + "args.vllm_importance_sampling_cap = 2.0\n"
 
             if "grpo" in trainer_file and trl_version >= Version("0.18.0"):
                 # If model has vllm_engine, then use vllm in colocate mode. Donot wait for server
-                vllm_setter += \
-                " " * 12 + "args.vllm_mode='colocate'\n"               
-    
+                vllm_setter += " " * 12 + "args.vllm_mode='colocate'\n"
+
             init = init.replace(replacer, replacer + vllm_setter)
-        pass
-    pass
 
-    #breakpoint()
+    # breakpoint()
 
     vllm_part = re.findall(
-        r"(\n[\s]{8}"\
-        r"if (self|args)\.use_vllm\:.*?"\
-        r"\n[\s]{8}"\
-        "else:\n)",
+        r"(\n[\s]{8}" r"if (self|args)\.use_vllm\:.*?" r"\n[\s]{8}" "else:\n)",
         init,
         flags = re.MULTILINE | re.DOTALL,
     )
-    
+
     if len(vllm_part) == 1:
         vllm_part, args = vllm_part[0][0], vllm_part[0][1]
         # Strip all comments
-        new_vllm_part = re.sub(r"^\s*\#[^\n]*\n?", "", vllm_part, flags=re.MULTILINE) # to also remove whole comment line instead of just starting at #
-        new_vllm_part = re.sub(r"\s*\#.*$", "", new_vllm_part, flags=re.MULTILINE) # remove comments that occur after code
+        new_vllm_part = re.sub(
+            r"^\s*\#[^\n]*\n?", "", vllm_part, flags = re.MULTILINE
+        )  # to also remove whole comment line instead of just starting at #
+        new_vllm_part = re.sub(
+            r"\s*\#.*$", "", new_vllm_part, flags = re.MULTILINE
+        )  # remove comments that occur after code
 
         # Get SamplingParams
         sampling_params = re.findall(
-            r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\
-            r"SamplingParams\(.+?\))",
+            r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}" r"SamplingParams\(.+?\))",
             new_vllm_part,
             flags = re.MULTILINE | re.DOTALL,
         )
@@ -986,35 +1073,46 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
             # Fix guided_decoding
             sampling_params = sampling_params.replace(
                 "guided_decoding=guided_decoding,",
-                'guided_decoding='\
-                'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\
+                "guided_decoding="
+                'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '
                 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None,',
             )
             # Replace with our vLLM engine
-            sampling_params = \
-                " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \
-                sampling_params # Add spaces
+            sampling_params = (
+                " " * 12
+                + "self.llm = model.vllm_engine; self._last_loaded_step = 0; "
+                + sampling_params
+            )  # Add spaces
 
             # count the indentation of last line of sampling_params.
             splitted_sampling_params = sampling_params.split("\n")
             if len(splitted_sampling_params) >= 2:
                 last_line = splitted_sampling_params[-1]
                 last_prev_line = splitted_sampling_params[-2]
-                last_prev_indentation = len(last_prev_line) - len(last_prev_line.lstrip())
+                last_prev_indentation = len(last_prev_line) - len(
+                    last_prev_line.lstrip()
+                )
                 last_indentation = len(last_line) - len(last_line.lstrip())
 
                 # Add extra arguments to SamplingParams
                 extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})"
                 # Backwards replace
-                to_replace = ",\n" + " "*last_prev_indentation + extra + ",\n" + " "*last_indentation + ")"
+                to_replace = (
+                    ",\n"
+                    + " " * last_prev_indentation
+                    + extra
+                    + ",\n"
+                    + " " * last_indentation
+                    + ")"
+                )
                 sampling_params = to_replace.join(sampling_params.rsplit(")", 1))
                 # Strip multiple commas
                 sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params)
 
-                new_vllm_part = \
-                    f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\
-                    f"\n{' '*8}else:\n"
-        pass
+                new_vllm_part = (
+                    f"\n{' ' * 8}if {args}.use_vllm:\n{sampling_params}"
+                    f"\n{' ' * 8}else:\n"
+                )
 
         if trl_version >= Version("0.18.0"):
             # Replace LLM init with already existing vLLM engine for colocate mode
@@ -1024,31 +1122,37 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
                 vllm_llm_init_pattern,
                 vllm_llm_replacement,
                 new_vllm_part,
-                flags=re.DOTALL  # Ensure . matches newlines [[5]]
+                flags = re.DOTALL,  # Ensure . matches newlines [[5]]
             )
 
         init = init.replace(vllm_part, new_vllm_part)
-    pass
 
     # Search for vLLM calling in all child functions
     functions = dir(RLTrainer)
     RLTrainer_source = inspect.getsource(RLTrainer)
     functions = [x for x in functions if f"def {x}" in RLTrainer_source]
 
-    changed = {"__init__" : (old_init, init,)}
+    changed = {
+        "__init__": (
+            old_init,
+            init,
+        )
+    }
     edit_functions = RL_FUNCTIONS.get(trainer_file, [])
 
     for function in functions:
-        if not hasattr(RLTrainer, function): continue
+        if not hasattr(RLTrainer, function):
+            continue
         fx = getattr(RLTrainer, function)
-        try: source = inspect.getsource(fx)
-        except: continue
+        try:
+            source = inspect.getsource(fx)
+        except:
+            continue
         original_source = source
 
         # Check for function
         for edit_function in edit_functions:
             source = edit_function(function, source)
-        pass
 
         """
         import torch
@@ -1088,8 +1192,10 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
         lora_name = trainer_file + "_lora_model"
         source = re.sub(
             r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)",
-            r"\1, lora_request = self.model.load_lora('" + lora_name + r"', load_tensors = True))",
-            source
+            r"\1, lora_request = self.model.load_lora('"
+            + lora_name
+            + r"', load_tensors = True))",
+            source,
         )
         # Prefer using unsloth's sampling params and fallback to trl's if not found
         # We'll enable this later separately when combining both this and GRPOConfig params
@@ -1100,13 +1206,16 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
         # )
 
         # Skip if no changes done
-        if source == original_source: continue
+        if source == original_source:
+            continue
 
         # Find all imports
         imports += [x for x in all_imports if not x.startswith("_") and x in source]
 
-        changed[function] = (original_source, source,)
-    pass
+        changed[function] = (
+            original_source,
+            source,
+        )
 
     # Import all functions
     imports = list(set(imports))
@@ -1115,29 +1224,27 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
     for function in changed:
         old, new = changed[function]
         RLTrainer_source = RLTrainer_source.replace(old, new)
-    pass
 
     RLTrainer_source = RLTrainer_source.replace(
         f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1
     )
     return RLTrainer_source
-pass
 
 
 def patch_trl_rl_trainers():
     # Patch all TRL modules if they have vLLM or PEFT
     import trl.trainer
+
     all_trainers = dir(trl.trainer)
     all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")]
     for trainer in all_trainers:
         _patch_trl_rl_trainers(trainer)
     return
-pass
 
 
 def PatchFastRL(algorithm = None, FastLanguageModel = None):
-    if FastLanguageModel is not None: PatchRL(FastLanguageModel)
+    if FastLanguageModel is not None:
+        PatchRL(FastLanguageModel)
     patch_trl_rl_trainers()
     if type(algorithm) is str and algorithm.islower():
         PatchRLStatistics(algorithm)
-pass
diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
index 498de22bb..47539c66f 100644
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -36,58 +36,66 @@
 )
 import textwrap
 
-RL_EXTRA_ARGS      = defaultdict(list)
-RL_FUNCTIONS       = defaultdict(list)
-RL_PRE_ITEMS       = defaultdict(list)
-RL_CONFIG_CHANGES  = defaultdict(list)
+RL_EXTRA_ARGS = defaultdict(list)
+RL_FUNCTIONS = defaultdict(list)
+RL_PRE_ITEMS = defaultdict(list)
+RL_CONFIG_CHANGES = defaultdict(list)
 RL_METRICS_CHANGES = defaultdict(list)
 
 torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : True,
-    "shape_padding"     : True,
-    "trace.enabled"     : False,
-    "triton.cudagraphs" : False,
+    "epilogue_fusion": True,
+    "max_autotune": True,
+    "shape_padding": True,
+    "trace.enabled": False,
+    "triton.cudagraphs": False,
 }
 
+
 # Check untrained tokens
 def sft_trainer_fix_untrained_tokens(call_args, extra_args):
     if "model" in call_args and "train_dataset" in call_args:
-        fix_tokenizer = \
-        "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"\
-        "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\
-        "from unsloth_zoo.training_utils  import fix_zero_training_loss\n"\
-        "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
-        "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\
-        "fix_zero_training_loss(model, tokenizer, train_dataset)\n"
+        fix_tokenizer = (
+            "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"
+            "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"
+            "from unsloth_zoo.training_utils  import fix_zero_training_loss\n"
+            "if 'tokenizer' not in locals(): tokenizer = processing_class\n"
+            "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"
+            "fix_zero_training_loss(model, tokenizer, train_dataset)\n"
+        )
         return fix_tokenizer
     return ""
-pass
+
+
 RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untrained_tokens)
 
 
 # Remove DPO columns which might randomnly be tokenized
 def dpo_trainer_fix_columns(call_args, extra_args):
     if "model" in call_args and "train_dataset" in call_args:
-        fix_dpo = \
-        "if hasattr(train_dataset, 'column_names'):\n"\
-        "    column_names = set(train_dataset.column_names)\n"\
-        "    check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
-        "             'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
-        "             'prompt_input_ids', 'prompt_attention_mask']\n"\
-        "    if all(x in column_names for x in check):\n"\
-        "        train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
-        "    del check, column_names\n"
+        fix_dpo = (
+            "if hasattr(train_dataset, 'column_names'):\n"
+            "    column_names = set(train_dataset.column_names)\n"
+            "    check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"
+            "             'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"
+            "             'prompt_input_ids', 'prompt_attention_mask']\n"
+            "    if all(x in column_names for x in check):\n"
+            "        train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"
+            "    del check, column_names\n"
+        )
         return fix_dpo
     return ""
-pass
+
+
 RL_EXTRA_ARGS["dpo_trainer"].append(dpo_trainer_fix_columns)
 
 
 # Fix tokenizer double BOS
 def sft_trainer_prepare_dataset(function_name, function):
-    if  function_name != "_prepare_non_packed_dataloader" and \
-        function_name != "_prepare_dataset": return function
+    if (
+        function_name != "_prepare_non_packed_dataloader"
+        and function_name != "_prepare_dataset"
+    ):
+        return function
 
     fast_sft_prepare_dataset = RL_REPLACEMENTS.get("sft_prepare_dataset", None)
     if fast_sft_prepare_dataset is not None:
@@ -102,35 +110,36 @@ def sft_trainer_prepare_dataset(function_name, function):
             # Use fast version!
             function = inspect.getsource(fast_sft_prepare_dataset)
             function = function.split("\n")
-            function = "\n".join(" "*4 + x for x in function)
-            function = function.replace("def sft_prepare_dataset", "def _prepare_dataset")
+            function = "\n".join(" " * 4 + x for x in function)
+            function = function.replace(
+                "def sft_prepare_dataset", "def _prepare_dataset"
+            )
             return function
-        pass
-    pass
-
-    check_text = \
-    "if 'skip_prepare_dataset' in locals() and skip_prepare_dataset:\n"\
-    "    return dataset\n"\
-    "if 'tokenizer'          not in locals(): tokenizer = processing_class\n"\
-    "if 'formatting_func'    not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\
-    "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\
-    "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\
-    "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\
-    "chat_template = getattr(tokenizer, 'chat_template', None)\n"\
-    "chat_template = '' if chat_template is None else chat_template\n"\
-    "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
-    "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\
-    "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\
-    "    from functools import partial\n"\
-    "    tokenizer_call = tokenizer.__call__\n"\
-    "    tokenizer.__call__ = partial(tokenizer_call, add_special_tokens = False)\n"\
-    "    processing_class = tokenizer\n"\
-    "else:\n"\
-    "    tokenizer_call = None\n"\
-    "    add_special_tokens = False if has_bos_token_already else locals().get('add_special_tokens', False)\n"
+
+    check_text = (
+        "if 'skip_prepare_dataset' in locals() and skip_prepare_dataset:\n"
+        "    return dataset\n"
+        "if 'tokenizer'          not in locals(): tokenizer = processing_class\n"
+        "if 'formatting_func'    not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"
+        "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"
+        "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"
+        "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"
+        "chat_template = getattr(tokenizer, 'chat_template', None)\n"
+        "chat_template = '' if chat_template is None else chat_template\n"
+        "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "
+        "if getattr(tokenizer, 'bos_token', None) is not None else False\n"
+        "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"
+        "    from functools import partial\n"
+        "    tokenizer_call = tokenizer.__call__\n"
+        "    tokenizer.__call__ = partial(tokenizer_call, add_special_tokens = False)\n"
+        "    processing_class = tokenizer\n"
+        "else:\n"
+        "    tokenizer_call = None\n"
+        "    add_special_tokens = False if has_bos_token_already else locals().get('add_special_tokens', False)\n"
+    )
 
     check_text = check_text.split("\n")
-    check_text = "\n".join(" "*8 + x for x in check_text)
+    check_text = "\n".join(" " * 8 + x for x in check_text)
     check_text = check_text.rstrip() + "\n"
 
     # .*? matches first match. .+? matches final match.
@@ -142,26 +151,31 @@ def sft_trainer_prepare_dataset(function_name, function):
     if len(replacer) != 0:
         replacer = replacer[0]
         function = function.replace(replacer, replacer + check_text)
-    pass
 
     # Return tokenizer's original state
-    return_state = "if tokenizer_call is not None: tokenizer.__call__ = tokenizer_call\n"
+    return_state = (
+        "if tokenizer_call is not None: tokenizer.__call__ = tokenizer_call\n"
+    )
     function = re.sub(
         r"\n([ ]{4,})(return .*?[\s]{0,})$",
         rf"\1{return_state}\1\2",
         function,
     )
     return function
-pass
+
+
 RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset)
 
 
 # Ignore mean_token_accuracy since it needs logits
 # We override it directly with our version
 def sft_trainer_compute_loss(function_name, function):
-    if  function_name != "compute_loss": return function
+    if function_name != "compute_loss":
+        return function
 
-    def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
+    def compute_loss(
+        self, model, inputs, return_outputs = False, num_items_in_batch = None
+    ):
         outputs = super().compute_loss(
             model,
             inputs,
@@ -169,25 +183,26 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             num_items_in_batch = num_items_in_batch,
         )
         return outputs
-    pass
 
     function = inspect.getsource(compute_loss)
     return function
-pass
+
+
 RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss)
 
 
 # Autocast precision for GRPO
 def grpo_trainer__prepare_inputs(function_name, function):
-    if  function_name != "_prepare_inputs": return function
+    if function_name != "_prepare_inputs":
+        return function
 
     # Add mixed precision training
     function = function.replace(
         "with torch.inference_mode():",
-        "with torch.inference_mode(), "\
-        "torch.amp.autocast(device_type = 'cuda', "\
-        "dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\
-        "if not torch.is_autocast_enabled('cuda') else nullcontext())"\
+        "with torch.inference_mode(), "
+        "torch.amp.autocast(device_type = 'cuda', "
+        "dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "
+        "if not torch.is_autocast_enabled('cuda') else nullcontext())"
         "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):",
     )
     function = function.replace(
@@ -195,13 +210,15 @@ def grpo_trainer__prepare_inputs(function_name, function):
         "self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)",
     )
     return function
-pass
+
+
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs)
 
 
 # Fix incorrect special tokens handling and truncation in older TRL versions
 def grpo_trainer__generate_and_score_completions(function_name, function):
-    if  function_name != "_generate_and_score_completions": return function
+    if function_name != "_generate_and_score_completions":
+        return function
 
     # TRL 0.19.0 did skip_special_tokens = True which should be False
     function = function.replace(
@@ -210,7 +227,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
     )
 
     # Left pad prompt before calculation old and ref hidden states
-    line_to_replace = "batch_size = self.args.per_device_train_batch_size if mode == \"train\" else self.args.per_device_eval_batch_size"
+    line_to_replace = 'batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size'
 
     # The new multi-line string that will replace the line above
     replacement_lines = """
@@ -233,7 +250,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         r"^\s*if self\.args\.gradient_accumulation_steps % generate_every != 0 or \(\s*"
         r"self\.use_vllm and self\.vllm_importance_sampling_correction\s*"
         r"\):",
-        re.MULTILINE
+        re.MULTILINE,
     )
 
     replacement_text = """        
@@ -245,26 +262,26 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
 
     pattern_to_find = re.compile(
         r"(^\s*)all_logprobs = \["  # Capture indentation (group 1)
-        r".*?"                      # Match everything inside non-greedily
+        r".*?"  # Match everything inside non-greedily
         r"for output in outputs\.outputs\s*"
         r"\]",
-        re.DOTALL | re.MULTILINE
+        re.DOTALL | re.MULTILINE,
     )
 
     replacement_text = (
-        r'\1from trl.scripts.vllm_serve import sanitize_logprob\n'
-        r'\1all_logprobs = [\n'
-        r'\1    [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]\n'
-        r'\1    for outputs in all_outputs\n'
-        r'\1    for output in outputs.outputs\n'
-        r'\1]'
+        r"\1from trl.scripts.vllm_serve import sanitize_logprob\n"
+        r"\1all_logprobs = [\n"
+        r"\1    [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]\n"
+        r"\1    for outputs in all_outputs\n"
+        r"\1    for output in outputs.outputs\n"
+        r"\1]"
     )
 
     function, num_replacements = pattern_to_find.subn(replacement_text, function)
-    
+
     # Always between max_prompt_length and use_vllm
     found = re.findall(
-        r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"\
+        r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"
         r"\2if self\.use_vllm:)",
         function,
         flags = re.DOTALL | re.MULTILINE,
@@ -273,10 +290,11 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
         replace_part, spacing = found[0]
         removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part)
         splits = removed_comments.split("\n")
-        if sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) >= 8:
-
-            new_replacement = \
-            f"""\n{spacing}if self.max_prompt_length is not None:
+        if (
+            sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2
+            and len(spacing) >= 8
+        ):
+            new_replacement = f"""\n{spacing}if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
             # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
@@ -313,20 +331,20 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
                 output["sampling_per_token_logps"] = None"""
 
     function = function.replace(string_to_find, replacement_string)
-    
-    if 'wake_up()' not in function:
+
+    if "wake_up()" not in function:
         # Sleep functionality has been added to trl in v0.23.0. We do not want to redo this.
         # https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709
 
-        pattern = re.compile(r'.*self\.llm\.generate\(.*\).*', re.MULTILINE)
+        pattern = re.compile(r".*self\.llm\.generate\(.*\).*", re.MULTILINE)
         matches = list(pattern.finditer(function))
         patched = function
 
         # Generally there's only one match. But this is just to make sure we don't miss any.
         for match in reversed(matches):
             line = match.group(0)
-            indent_match = re.match(r'(\s*)', line)
-            indent = indent_match.group(1) if indent_match else ''
+            indent_match = re.match(r"(\s*)", line)
+            indent = indent_match.group(1) if indent_match else ""
 
             wrapped = (
                 f"{indent}if hasattr(self, 'llm'):\n"
@@ -338,18 +356,21 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
                 f"{indent}        self.llm.sleep(os.environ.get('VLLM_SLEEP_MODE', 1))\n"
             )
 
-            patched = patched[:match.start()] + wrapped + patched[match.end():]
+            patched = patched[: match.start()] + wrapped + patched[match.end() :]
 
         function = patched
 
     return function
-pass
+
+
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)
 
+
 # Fix {"reasoning_effort" : "high"} not applied
 def grpo_trainer_fix_maybe_apply_chat_template(function_name, function):
     spaces = function.find("def ")
-    if spaces % 4 != 0: return function
+    if spaces % 4 != 0:
+        return function
     spaces += 4
     replacement = """
         _chat_template_ = getattr(self.processing_class, "chat_template", None)
@@ -371,51 +392,66 @@ def grpo_trainer_fix_maybe_apply_chat_template(function_name, function):
             prompts_text.append(_x_)
     """
     replacement = textwrap.dedent(replacement).strip()
-    replacement = textwrap.indent(replacement, spaces*" ")
+    replacement = textwrap.indent(replacement, spaces * " ")
     replacement = f"\n{replacement}\n"
     what = 'prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]'
-    function = function.replace(what, replacement.replace("__INPUTS__REPLACEMENT__", "inputs"))
+    function = function.replace(
+        what, replacement.replace("__INPUTS__REPLACEMENT__", "inputs")
+    )
 
     """prompts_text = [
         maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
     ]"""
     function = re.sub(
-        r"prompts_text = \["\
-        r"[\s]{0,}"\
-        r"maybe_apply_chat_template\(\{[\"\']prompt[\"\'][\s]{0,}\:[\s]{0,}prompt[\s]{0,}\}[\s]{0,}\,[\s]{0,}self\.processing_class\)"\
-        r"\[[\"\']prompt[\"\']\] for prompt in prompts"\
-        r"[\s]{0,}"\
+        r"prompts_text = \["
+        r"[\s]{0,}"
+        r"maybe_apply_chat_template\(\{[\"\']prompt[\"\'][\s]{0,}\:[\s]{0,}prompt[\s]{0,}\}[\s]{0,}\,[\s]{0,}self\.processing_class\)"
+        r"\[[\"\']prompt[\"\']\] for prompt in prompts"
+        r"[\s]{0,}"
         r"\]",
         replacement.replace("__INPUTS__REPLACEMENT__", "prompts"),
         function,
     )
     return function
-pass
+
+
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_fix_maybe_apply_chat_template)
 
+
 # Remove _move_model_to_vllm
 def grpo_trainer__move_model_to_vllm(function_name, function):
-    if  function_name != "_move_model_to_vllm": return function
+    if function_name != "_move_model_to_vllm":
+        return function
 
-    def _move_model_to_vllm(self, *args, **kwargs): return None
+    def _move_model_to_vllm(self, *args, **kwargs):
+        return None
 
     function = inspect.getsource(_move_model_to_vllm)
     return function
-pass
+
+
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm)
 
 
 # Edit _get_per_token_logps to handle mixed precision
 def grpo_trainer__get_per_token_logps(function_name, function):
-    if function_name != "_get_per_token_logps": return function
-
-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False):
-        if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
-            return None # Unsloth efficient GRPO
+    if function_name != "_get_per_token_logps":
+        return function
+
+    def _get_per_token_logps(
+        self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False
+    ):
+        if True:  # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+            return None  # Unsloth efficient GRPO
         # Otherwise, calculate normally:
-        if not hasattr(self, '_autocast_dtype'):
-            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
-            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+        if not hasattr(self, "_autocast_dtype"):
+            self._autocast_dtype = (
+                torch.float16
+                if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
+                else torch.bfloat16
+            )
+            if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
+                self._autocast_dtype = torch.float16
 
         os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
         with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = self._autocast_dtype):
@@ -443,41 +479,65 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
             #     breakpoint()  # Breakpoint triggered here
             #     print("Found high values!")
             # return  logps #  compute logprobs for the input tokens
-        pass
-    pass
 
     function = inspect.getsource(_get_per_token_logps)
     return function
-pass
+
+
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps)
 
+
 def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
-    if function_name != "_get_per_token_logps_and_entropies": return function
+    if function_name != "_get_per_token_logps_and_entropies":
+        return function
 
     # Just copy over from _get_per_token_logps replacement function above. For now this returns None anyway
-    def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, logits_to_keep, batch_size = None,
-                                           compute_entropy = False, compute_efficient = False, *args, **kwargs):
+    def _get_per_token_logps_and_entropies(
+        self,
+        model,
+        input_ids,
+        attention_mask,
+        logits_to_keep,
+        batch_size = None,
+        compute_entropy = False,
+        compute_efficient = False,
+        *args,
+        **kwargs,
+    ):
         # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
         #     return None, None  # logps, entropies Unsloth efficient GRPO
         if compute_efficient:
             return None, None
         else:
             # Otherwise, calculate normally:
-            if not hasattr(self, '_autocast_dtype'):
-                self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
-                if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+            if not hasattr(self, "_autocast_dtype"):
+                self._autocast_dtype = (
+                    torch.float16
+                    if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
+                    else torch.bfloat16
+                )
+                if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
+                    self._autocast_dtype = torch.float16
 
-            pixel_values, image_grid_thw = kwargs.get("pixel_values", None), kwargs.get("image_grid_thw", None)
-            pixel_attention_mask, image_sizes = kwargs.get('pixel_attention_mask',None), kwargs.get('image_sizes',None)
+            pixel_values, image_grid_thw = (
+                kwargs.get("pixel_values", None),
+                kwargs.get("image_grid_thw", None),
+            )
+            pixel_attention_mask, image_sizes = (
+                kwargs.get("pixel_attention_mask", None),
+                kwargs.get("image_sizes", None),
+            )
 
             os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
 
-            unwrapped_model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False)
+            unwrapped_model = self.accelerator.unwrap_model(
+                model, keep_fp32_wrapper = False
+            )
 
-            with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+            with torch.amp.autocast(device_type = "cuda", dtype = self._autocast_dtype):
                 with torch.inference_mode():
                     if pixel_values is None:
-                        attention_mask =  input_ids != self.processing_class.pad_token_id
+                        attention_mask = input_ids != self.processing_class.pad_token_id
                         attention_mask = attention_mask.to(attention_mask.dtype)
                         # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
                         logits = unwrapped_model(
@@ -487,7 +547,7 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
                             image_grid_thw = image_grid_thw,
                             pixel_attention_mask = pixel_attention_mask,
                             image_sizes = image_sizes,
-                            #logits_to_keep = logits_to_keep + 1,
+                            # logits_to_keep = logits_to_keep + 1,
                         ).logits
                     else:
                         logits = unwrapped_model(
@@ -500,12 +560,11 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
                             logits_to_keep = logits_to_keep + 1,
                         ).logits
 
-
                 entropies = None
                 if compute_entropy:
                     from trl.trainer.utils import entropy_from_logits
-                    entropies = entropy_from_logits(logits)
 
+                    entropies = entropy_from_logits(logits)
 
             os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
             # logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
@@ -526,57 +585,88 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l
             #     breakpoint()  # Breakpoint triggered here
             #     print("Found high values!")
             # return  logps #  compute logprobs for the input tokens
-        pass
-    pass
 
     function = inspect.getsource(_get_per_token_logps_and_entropies)
     return function
-pass
+
+
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies)
 
-grpo_compute_loss      = RL_REPLACEMENTS["grpo_compute_loss"]
+grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
 grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
-UnslothEfficientGRPO   = RL_REPLACEMENTS["UnslothEfficientGRPO"]
-grpo_accumulated_loss  = RL_REPLACEMENTS["grpo_accumulated_loss"]
+UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
+grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"]
 RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss))
 RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO))
 RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss))
 RL_PRE_ITEMS["grpo_trainer"].append(grpo_compute_loss_slow)
 
+
 # Edit _get_per_token_logps to handle mixed precision
 def grpo_trainer_compute_loss(function_name, function):
-    if  function_name != "compute_loss": return function
+    if function_name != "compute_loss":
+        return function
 
-    def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
+    def compute_loss(
+        self, model, inputs, return_outputs = False, num_items_in_batch = None
+    ):
         if return_outputs:
             raise ValueError("The GRPOTrainer does not support returning outputs")
         # Compute the per-token log probabilities for the model
 
-
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
-        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
-        pixel_values, image_grid_thw = inputs.get("pixel_values", None), inputs.get("image_grid_thw", None)
-        pixel_attention_mask, image_sizes = inputs.get('pixel_attention_mask',None), inputs.get('image_sizes',None)
-        num_items_in_batch  = inputs.get("num_items_in_batch", None)
-        sampling_per_token_logps = inputs.get("sampling_per_token_logps", None)   
+        completion_ids, completion_mask = (
+            inputs["completion_ids"],
+            inputs["completion_mask"],
+        )
+        pixel_values, image_grid_thw = (
+            inputs.get("pixel_values", None),
+            inputs.get("image_grid_thw", None),
+        )
+        pixel_attention_mask, image_sizes = (
+            inputs.get("pixel_attention_mask", None),
+            inputs.get("image_sizes", None),
+        )
+        num_items_in_batch = inputs.get("num_items_in_batch", None)
+        sampling_per_token_logps = inputs.get("sampling_per_token_logps", None)
         current_gradient_accumulation_steps = self.current_gradient_accumulation_steps
         num_processes = self.accelerator.num_processes
 
-        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        input_ids = torch.cat([prompt_ids, completion_ids], dim = 1)
         bsz, qlen = input_ids.shape
-        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1)
         # attention_mask = None
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        logits_to_keep = completion_ids.size(
+            1
+        )  # we only need to compute the logits for the completion tokens
         _input_ids = input_ids
         _logits_to_keep = logits_to_keep
 
-        get_logps_func = \
-            lambda model, input_ids, attention_mask, logits_to_keep, batch_size=None, compute_entropy=False, compute_efficient = False: \
-            self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, compute_efficient) \
-            if hasattr(self, "_get_per_token_logps") else \
-            self._get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep, batch_size, compute_entropy, compute_efficient)[0]  # logps
-
-        per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep, compute_efficient = True)
+        get_logps_func = (
+            lambda model,
+            input_ids,
+            attention_mask,
+            logits_to_keep,
+            batch_size = None,
+            compute_entropy = False,
+            compute_efficient = False: self._get_per_token_logps(
+                model, input_ids, attention_mask, logits_to_keep, compute_efficient
+            )
+            if hasattr(self, "_get_per_token_logps")
+            else self._get_per_token_logps_and_entropies(
+                model,
+                input_ids,
+                attention_mask,
+                logits_to_keep,
+                batch_size,
+                compute_entropy,
+                compute_efficient,
+            )[0]
+        )  # logps
+
+        per_token_logps = get_logps_func(
+            model, input_ids, attention_mask, logits_to_keep, compute_efficient = True
+        )
         # Compute the KL divergence between the model and the reference model
         # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
         # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
@@ -597,59 +687,40 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         input_ids = input_ids[:, -logits_to_keep:]
 
         # Get logit softcapping and logit scale
-        logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma
-        if logit_softcapping is None: logit_softcapping = 0
-        logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere
-        if logit_scale_multiply is None: logit_scale_multiply = 0
-        logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
-        if logit_scale_divide is None: logit_scale_divide = 0
+        logit_softcapping = getattr(model.config, "final_logit_softcapping", 0)  # Gemma
+        if logit_softcapping is None:
+            logit_softcapping = 0
+        logit_scale_multiply = getattr(model.config, "logit_scale", 0)  # Cohere
+        if logit_scale_multiply is None:
+            logit_scale_multiply = 0
+        logit_scale_divide = getattr(model.config, "logits_scaling", 0)  # Granite
+        if logit_scale_divide is None:
+            logit_scale_divide = 0
 
         if per_token_logps is not None:
-
             if ref_hidden_states is not None:
-                ref_hidden_states = ref_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+                ref_hidden_states = ref_hidden_states[
+                    :, :-1, :
+                ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
             if old_hidden_states is not None:
-                old_hidden_states = old_hidden_states[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-            per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-
-            loss, completion_length, mean_kl, delta, flat_is_ratio = grpo_compute_loss_slow(
-                ref_hidden_states,
-                per_token_logps,
-                old_hidden_states,
-                input_ids,
-                completion_mask,
-                self.beta,
-                advantages,
-                pixel_values = pixel_values,
-                image_grid_thw = image_grid_thw,
-                loss_type = self.args.loss_type,
-                importance_sampling_level = self.importance_sampling_level,
-                epsilon_low = self.epsilon_low,
-                epsilon_high = self.epsilon_high,
-                max_completion_length = self.args.max_completion_length,
-                delta = self.args.delta,
-                temperature = self.args.temperature,
-                logit_softcapping = logit_softcapping,
-                logit_scale_multiply = logit_scale_multiply,
-                logit_scale_divide = logit_scale_divide,
-                num_items_in_batch = num_items_in_batch, 
-                current_gradient_accumulation_steps = current_gradient_accumulation_steps,
-                num_processes = num_processes,
-                sampling_per_token_logps  = sampling_per_token_logps,
-            )
-        else:
-            if hasattr(self.args, "loss_type"):
-                loss, completion_length, mean_kl, delta, flat_is_ratio = grpo_accumulated_loss(
-                    trainer = self,
-                    input_ids = _input_ids,
+                old_hidden_states = old_hidden_states[
+                    :, :-1, :
+                ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            per_token_logps = per_token_logps[
+                :, :-1, :
+            ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+            loss, completion_length, mean_kl, delta, flat_is_ratio = (
+                grpo_compute_loss_slow(
+                    ref_hidden_states,
+                    per_token_logps,
+                    old_hidden_states,
+                    input_ids,
+                    completion_mask,
+                    self.beta,
+                    advantages,
                     pixel_values = pixel_values,
                     image_grid_thw = image_grid_thw,
-                    logits_to_keep = logits_to_keep,
-                    completion_mask = completion_mask,
-                    advantages = advantages,
-                    old_hidden_states = old_hidden_states,
-                    ref_hidden_states = ref_hidden_states,
-                    n_chunks = self.args.unsloth_num_chunks,
                     loss_type = self.args.loss_type,
                     importance_sampling_level = self.importance_sampling_level,
                     epsilon_low = self.epsilon_low,
@@ -660,11 +731,42 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
                     logit_softcapping = logit_softcapping,
                     logit_scale_multiply = logit_scale_multiply,
                     logit_scale_divide = logit_scale_divide,
-                    attention_mask = attention_mask,
-                    num_items_in_batch = num_items_in_batch, 
+                    num_items_in_batch = num_items_in_batch,
                     current_gradient_accumulation_steps = current_gradient_accumulation_steps,
                     num_processes = num_processes,
-                    sampling_per_token_logps  = sampling_per_token_logps,
+                    sampling_per_token_logps = sampling_per_token_logps,
+                )
+            )
+        else:
+            if hasattr(self.args, "loss_type"):
+                loss, completion_length, mean_kl, delta, flat_is_ratio = (
+                    grpo_accumulated_loss(
+                        trainer = self,
+                        input_ids = _input_ids,
+                        pixel_values = pixel_values,
+                        image_grid_thw = image_grid_thw,
+                        logits_to_keep = logits_to_keep,
+                        completion_mask = completion_mask,
+                        advantages = advantages,
+                        old_hidden_states = old_hidden_states,
+                        ref_hidden_states = ref_hidden_states,
+                        n_chunks = self.args.unsloth_num_chunks,
+                        loss_type = self.args.loss_type,
+                        importance_sampling_level = self.importance_sampling_level,
+                        epsilon_low = self.epsilon_low,
+                        epsilon_high = self.epsilon_high,
+                        max_completion_length = self.args.max_completion_length,
+                        delta = self.args.delta,
+                        temperature = self.args.temperature,
+                        logit_softcapping = logit_softcapping,
+                        logit_scale_multiply = logit_scale_multiply,
+                        logit_scale_divide = logit_scale_divide,
+                        attention_mask = attention_mask,
+                        num_items_in_batch = num_items_in_batch,
+                        current_gradient_accumulation_steps = current_gradient_accumulation_steps,
+                        num_processes = num_processes,
+                        sampling_per_token_logps = sampling_per_token_logps,
+                    )
                 )
             else:
                 # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
@@ -683,8 +785,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
                     logit_scale_divide = logit_scale_divide,
                     attention_mask = attention_mask,
                 )
-            pass
-        pass
 
         if "train" in self._metrics:
             mode = "eval" if self.control.should_evaluate else "train"
@@ -695,8 +795,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             self._metrics["kl"].append(mean_kl.item())
 
         if self.use_vllm and delta is not None:
-            mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=self.model.device)
-            max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=self.model.device)
+            mean_delta = (
+                torch.mean(delta)
+                if delta.numel() > 0
+                else torch.tensor(0.0, device = self.model.device)
+            )
+            max_delta = (
+                torch.max(delta)
+                if delta.numel() > 0
+                else torch.tensor(0.0, device = self.model.device)
+            )
             self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
                 self.accelerator.gather(mean_delta).mean().item()
             )
@@ -705,13 +813,19 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             )
 
             min_importance_sampling_ratio = (
-                torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=self.model.device)
+                torch.min(flat_is_ratio)
+                if flat_is_ratio.numel() > 0
+                else torch.tensor(0.0, device = self.model.device)
             )
             mean_importance_sampling_ratio = (
-                torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=self.model.device)
+                torch.mean(flat_is_ratio)
+                if flat_is_ratio.numel() > 0
+                else torch.tensor(0.0, device = self.model.device)
             )
             max_importance_sampling_ratio = (
-                torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=self.model.device)
+                torch.max(flat_is_ratio)
+                if flat_is_ratio.numel() > 0
+                else torch.tensor(0.0, device = self.model.device)
             )
             self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
                 nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
@@ -724,56 +838,63 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             )
 
         return loss
-    pass
 
     function = inspect.getsource(compute_loss)
     return function
-pass
+
+
 RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss)
 
+
 # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356
 # TRL warns if batch size is not a multiple of num_generations -> fix this.
 def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source):
-    if "divisible by the number of generations" not in RLTrainer_source: return ""
-    if "num_generations" not in RLConfig_source: return ""
-
-    check_batch_size = \
-    "div = per_device_train_batch_size // num_generations\n"\
-    "if div * num_generations != per_device_train_batch_size:\n"\
-    "    print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\
-               "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\
-    "    per_device_train_batch_size = num_generations\n"
+    if "divisible by the number of generations" not in RLTrainer_source:
+        return ""
+    if "num_generations" not in RLConfig_source:
+        return ""
+
+    check_batch_size = (
+        "div = per_device_train_batch_size // num_generations\n"
+        "if div * num_generations != per_device_train_batch_size:\n"
+        "    print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"
+        "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"
+        "    per_device_train_batch_size = num_generations\n"
+    )
     return check_batch_size
-pass
+
+
 RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size)
 
 
 # Add other reward function names
 def grpo_trainer_metrics(RLTrainer_source, RLConfig_source):
-    if "reward_funcs" not in RLTrainer_source: return ""
+    if "reward_funcs" not in RLTrainer_source:
+        return ""
 
     # For new TRL we have /mean and /std
     use_mean = "rewards/{reward_func_name}/mean" in RLTrainer_source
-    use_std  = "rewards/{reward_func_name}/std"  in RLTrainer_source
+    use_std = "rewards/{reward_func_name}/std" in RLTrainer_source
     if not use_mean:
         use_normal = "rewards/{reward_func_name}" in RLTrainer_source
     else:
         use_normal = False
-    pass
-
-    log_metrics = \
-    "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\
-    "else: _reward_funcs = reward_funcs\n"\
-    "for reward_func in _reward_funcs:\n"\
-    "    try:\n"\
-    "        reward_func_name = reward_func.__name__\n"\
-   f"        if {use_mean}:\n"\
-    "            other_metrics.append(f'rewards/{reward_func_name}/mean')\n"\
-   f"        if {use_std}:\n"\
-    "            other_metrics.append(f'rewards/{reward_func_name}/std')\n"\
-   f"        if {use_normal}:\n"\
-    "            other_metrics.append(f'rewards/{reward_func_name}')\n"\
-    "    except: pass\n"
+
+    log_metrics = (
+        "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"
+        "else: _reward_funcs = reward_funcs\n"
+        "for reward_func in _reward_funcs:\n"
+        "    try:\n"
+        "        reward_func_name = reward_func.__name__\n"
+        f"        if {use_mean}:\n"
+        "            other_metrics.append(f'rewards/{reward_func_name}/mean')\n"
+        f"        if {use_std}:\n"
+        "            other_metrics.append(f'rewards/{reward_func_name}/std')\n"
+        f"        if {use_normal}:\n"
+        "            other_metrics.append(f'rewards/{reward_func_name}')\n"
+        "    except: pass\n"
+    )
     return log_metrics
-pass
+
+
 RL_METRICS_CHANGES["grpo_trainer"].append(grpo_trainer_metrics)
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index f6661eee6..9a1d39c9c 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -19,12 +19,13 @@
     AutoTokenizer,
     AutoModelForCausalLM,
 )
+
 try:
     from transformers import AutoModelForImageTextToText
+
     AutoModelForVision2Seq = AutoModelForImageTextToText
 except:
     from transformers import AutoModelForVision2Seq
-pass
 from ..kernels import (
     post_patch_loss_function,
 )
@@ -65,12 +66,12 @@
 import re, inspect, sys
 import contextlib
 import types
+
 try:
     from huggingface_hub.utils import get_token
 except:
     # Old HF Hub versions <= 0.0.25
     from huggingface_hub.utils._token import get_token
-pass
 from ..device_type import (
     is_hip,
     get_device_type,
@@ -100,8 +101,10 @@
 ]
 
 from transformers import GenerationConfig, CompileConfig, HybridCache, AutoConfig
+
 try:
     from transformers import PreTrainedConfig
+
     PretrainedConfig = PreTrainedConfig
 except:
     from transformers import PretrainedConfig
@@ -115,7 +118,7 @@
     dynamic = None,
     mode = "reduce-overhead",
 )
-_compile_config.disable = True # Must set manually
+_compile_config.disable = True  # Must set manually
 
 from unsloth_zoo.vllm_utils import (
     convert_lora_modules,
@@ -126,7 +129,7 @@
     torch_compiler_set_stance = torch.compiler.set_stance
 except:
     torch_compiler_set_stance = None
-pass
+
 
 def unsloth_base_fast_generate(
     self,
@@ -150,8 +153,7 @@ def unsloth_base_fast_generate(
         if type(kwargs["key"]) is not torch.Tensor:
             raise TypeError("Unsloth: You need to pass in input_ids to .generate!")
         input_ids = kwargs[key]
-    pass
-    assert(type(input_ids) is torch.Tensor)
+    assert type(input_ids) is torch.Tensor
     bsz = input_ids.shape[0]
 
     FastBaseModel.for_inference(self)
@@ -189,11 +191,8 @@ def unsloth_base_fast_generate(
                     NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep"
                     break
             m = m.model
-        pass
         if arch not in NUM_LOGITS_TO_KEEP:
             NUM_LOGITS_TO_KEEP[arch] = None
-        pass
-    pass
     key = NUM_LOGITS_TO_KEEP[arch]
     if key is not None and key not in kwargs:
         kwargs[key] = 1
@@ -206,8 +205,10 @@ def unsloth_base_fast_generate(
     kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
 
     # Get pixel values for VLMs
-    try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype)
-    except: pass
+    try:
+        kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype)
+    except:
+        pass
 
     # Mixed precision autocast
     if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
@@ -232,7 +233,9 @@ def unsloth_base_fast_generate(
     # Fix generation_config
     # Use hybrid if sliding window seen, otherwise try static
     cache_implementation = getattr(self.config, "cache_implementation", None)
-    if getattr(self, "_supports_static_cache", getattr(self, "_can_compile_fullgraph", True)):
+    if getattr(
+        self, "_supports_static_cache", getattr(self, "_can_compile_fullgraph", True)
+    ):
         if os.environ.get("UNSLOTH_DISABLE_STATIC_GENERATION", "0") == "0":
             cache_implementation = "static"
         elif Version(transformers_version) < Version("4.56.0.dev0"):
@@ -243,9 +246,12 @@ def unsloth_base_fast_generate(
     else:
         cache_implementation = None
     if cache_implementation is not None:
-        swa = getattr(getattr(self.config, "text_config", self.config), "sliding_window", None)
-        if (swa == 0 or type(swa) is not int) \
-            and (getattr(self, "_can_compile_fullgraph", True) is True):
+        swa = getattr(
+            getattr(self.config, "text_config", self.config), "sliding_window", None
+        )
+        if (swa == 0 or type(swa) is not int) and (
+            getattr(self, "_can_compile_fullgraph", True) is True
+        ):
             cache_implementation = "static"
         else:
             if Version(transformers_version) < Version("4.56.0.dev0"):
@@ -261,18 +267,20 @@ def unsloth_base_fast_generate(
         kwargs["cache_implementation"] = cache_implementation
         if cache_implementation is not None:
             kwargs["compile_config"] = _compile_config
-    pass
 
     # Delete cached Flex Attention masks to reset inference
     for name, module in self.named_modules():
         if hasattr(module, "_flex_attention_cache"):
-            try: del module._flex_attention_cache
-            except: pass
+            try:
+                del module._flex_attention_cache
+            except:
+                pass
         # Solves AttributeError: 'SlidingWindowLayer' object has no attribute 'max_batch_size'
         if hasattr(module, "_cache") and "cache_utils" in str(module._cache.__class__):
-            try: del module._cache
-            except: pass
-    pass
+            try:
+                del module._cache
+            except:
+                pass
 
     # DO INFERENCE
     with torch.inference_mode(), autocaster:
@@ -281,54 +289,57 @@ def unsloth_base_fast_generate(
     # Delete cached Flex Attention masks to reset inference
     for name, module in self.named_modules():
         if hasattr(module, "_flex_attention_cache"):
-            try: del module._flex_attention_cache
-            except: pass
+            try:
+                del module._flex_attention_cache
+            except:
+                pass
         # Solves AttributeError: 'SlidingWindowLayer' object has no attribute 'max_batch_size'
         if hasattr(module, "_cache") and "cache_utils" in str(module._cache.__class__):
-            try: del module._cache
-            except: pass
-    pass
+            try:
+                del module._cache
+            except:
+                pass
 
     # FastBaseModel.for_training(self)
     return output
-pass
 
-class FastBaseModel:
 
+class FastBaseModel:
     @staticmethod
     def from_pretrained(
-        model_name        = "unsloth/Llama-3.2-1B-Instruct",
-        max_seq_length    = 2048,
-        dtype             = None,
-        load_in_4bit      = True,
-        load_in_8bit      = False,
-        load_in_16bit     = False,
-        full_finetuning   = False,
-        token             = None,
-        device_map        = "sequential",
+        model_name = "unsloth/Llama-3.2-1B-Instruct",
+        max_seq_length = 2048,
+        dtype = None,
+        load_in_4bit = True,
+        load_in_8bit = False,
+        load_in_16bit = False,
+        full_finetuning = False,
+        token = None,
+        device_map = "sequential",
         trust_remote_code = False,
-        model_types       = None,
-        tokenizer_name    = None,
-        auto_model        = AutoModelForVision2Seq,
+        model_types = None,
+        tokenizer_name = None,
+        auto_model = AutoModelForVision2Seq,
         use_gradient_checkpointing = "unsloth",
-        supports_sdpa     = True,
-        whisper_language  = None,
-        whisper_task      = None,
-        auto_config       = None,
+        supports_sdpa = True,
+        whisper_language = None,
+        whisper_task = None,
+        auto_config = None,
         offload_embedding = False,
         # vLLM parameters
-        fast_inference    = False,
+        fast_inference = False,
         gpu_memory_utilization = 0.5,
-        float8_kv_cache   = False,
-        random_state      = 3407,
-        max_lora_rank     = 64,
+        float8_kv_cache = False,
+        random_state = 3407,
+        max_lora_rank = 64,
         disable_log_stats = False,
         unsloth_vllm_standby = False,
         **kwargs,
     ):
         if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") != "1":
-            raise RuntimeError("Unsloth: UNSLOTH_VLLM_STANDBY is True, but UNSLOTH_VLLM_STANDBY is not set to 1!")
-        pass
+            raise RuntimeError(
+                "Unsloth: UNSLOTH_VLLM_STANDBY is True, but UNSLOTH_VLLM_STANDBY is not set to 1!"
+            )
 
         if model_types is None:
             raise RuntimeError(
@@ -337,14 +348,15 @@ def from_pretrained(
         if os.environ.get("UNSLOTH_MODEL_NAME", "") == "":
             os.environ["UNSLOTH_MODEL_NAME"] = model_name.lower()
 
-        is_vlm = (auto_model in [AutoModelForVision2Seq, AutoModelForImageTextToText])
-        is_whisper = (whisper_language is not None and whisper_task is not None)
+        is_vlm = auto_model in [AutoModelForVision2Seq, AutoModelForImageTextToText]
+        is_whisper = whisper_language is not None and whisper_task is not None
         auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
 
         model_type_arch = model_types[0]
         if model_type_arch == "siglip":
             for model_type_arch in model_types:
-                if model_type_arch != "siglip": break
+                if model_type_arch != "siglip":
+                    break
 
         vllm_enable_lora = True
 
@@ -365,30 +377,40 @@ def from_pretrained(
         os.environ["UNSLOTH_USE_NEW_MODEL"] = "1"
         if trust_remote_code:
             print(
-                "Unsloth: WARNING `trust_remote_code` is True.\n"\
+                "Unsloth: WARNING `trust_remote_code` is True.\n"
                 "Are you certain you want to do remote code execution?"
             )
-        pass
-        if token is None: token = get_token()
+        if token is None:
+            token = get_token()
         SUPPORTS_BFLOAT16 = is_bfloat16_supported()
 
         if DEVICE_TYPE == "cuda":
             gpu_stats = torch.cuda.get_device_properties(0)
-            gpu_stats_name = gpu_stats.name + ". " if gpu_stats.name != "" else "NVIDIA GPU Device. "
+            gpu_stats_name = (
+                gpu_stats.name + ". " if gpu_stats.name != "" else "NVIDIA GPU Device. "
+            )
             gpu_version = torch.version.cuda
             gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
-            try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except: vllm_version = ""
+            try:
+                vllm_version = f" vLLM: {importlib_version('vllm')}."
+            except:
+                vllm_version = ""
         elif DEVICE_TYPE == "hip":
             gpu_stats = torch.cuda.get_device_properties(0)
-            gpu_stats_name = gpu_stats.name + ". " if gpu_stats.name != "" else "AMD GPU Device. "
+            gpu_stats_name = (
+                gpu_stats.name + ". " if gpu_stats.name != "" else "AMD GPU Device. "
+            )
             gpu_version = torch.version.hip
             gpu_stats_snippet = f"ROCm Toolkit: {gpu_version}."
-            try:    vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except: vllm_version = ""
+            try:
+                vllm_version = f" vLLM: {importlib_version('vllm')}."
+            except:
+                vllm_version = ""
         elif DEVICE_TYPE == "xpu":
             gpu_stats = torch.xpu.get_device_properties(0)
-            gpu_stats_name = gpu_stats.name + ". " if gpu_stats.name != "" else "Intel XPU Device. "
+            gpu_stats_name = (
+                gpu_stats.name + ". " if gpu_stats.name != "" else "Intel XPU Device. "
+            )
             gpu_version = torch.version.xpu
             gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
             # [TODO] After adding vLLM support for XPU, change this
@@ -398,26 +420,31 @@ def from_pretrained(
 
         max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
 
-        statistics = \
-        f"==((====))==  Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
-        f"   {chr(92)}{chr(92)}   /|    {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
-        f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\
-        f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
-        f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+        statistics = (
+            f"==((====))==  Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"
+            f"   {chr(92)}{chr(92)}   /|    {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"
+            f"O^O/ {chr(92)}_/ {chr(92)}    Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"
+            f"{chr(92)}        /    Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"
+            f' "-____-"     Free license: http://github.com/unslothai/unsloth'
+        )
 
         print(statistics)
 
         # Warn about fast transfers
         if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
             old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"]
-            if old_hf_transfer in ("False", "false"): old_hf_transfer = "0"
-            if old_hf_transfer in ("True",  "true" ): old_hf_transfer = "1"
+            if old_hf_transfer in ("False", "false"):
+                old_hf_transfer = "0"
+            if old_hf_transfer in ("True", "true"):
+                old_hf_transfer = "1"
         else:
             old_hf_transfer = "0"
         if old_hf_transfer == "1":
-            print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
-        pass
-        if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+            print(
+                "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!"
+            )
+        if old_hf_transfer != "0":
+            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
         # For debugging - we use a download counter to see if environments are not breaking or if HF is down
         get_statistics(kwargs.get("local_files_only", False))
@@ -425,20 +452,23 @@ def from_pretrained(
         if dtype is None:
             dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
         elif os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
-            if dtype == torch.float16: dtype = torch.bfloat16
+            if dtype == torch.float16:
+                dtype = torch.bfloat16
         elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
-            logger.warning_once("Device does not support bfloat16. Will change to float16.")
+            logger.warning_once(
+                "Device does not support bfloat16. Will change to float16."
+            )
             dtype = torch.float16
-        pass
-        assert(dtype in (torch.float16, torch.bfloat16, torch.float32))
+        assert dtype in (torch.float16, torch.bfloat16, torch.float32)
 
         bnb_compute_dtype = dtype
         do_forced_float32 = False
         if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
-            print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.")
+            print(
+                f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32."
+            )
             bnb_compute_dtype = torch.float16
             do_forced_float32 = True
-        pass
 
         # Check for custom data-types
         custom_datatype = None
@@ -446,13 +476,17 @@ def from_pretrained(
         if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "":
             custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"]
             assert custom_datatype.count(";") >= 4
-            checker, _dtype, _bnb_compute_dtype, _custom_datatype, execute_code = custom_datatype.split(";", 4)
+            checker, _dtype, _bnb_compute_dtype, _custom_datatype, execute_code = (
+                custom_datatype.split(";", 4)
+            )
             # Allow custom dtypes on all runs
-            allow_all_runs = (checker == "all")
+            allow_all_runs = checker == "all"
             # Allow only on float16 datatypes
             allow_float16_runs = (
-                (checker == "float16" or checker == "torch.float16") and \
-                (dtype == torch.float16 or os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1")
+                checker == "float16" or checker == "torch.float16"
+            ) and (
+                dtype == torch.float16
+                or os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1"
             )
             if allow_all_runs or allow_float16_runs:
                 if eval(_dtype) is not None:
@@ -467,55 +501,62 @@ def from_pretrained(
             else:
                 custom_datatype = None
                 correct_dtype = None
-        pass
 
         # Stop SDPA for some archs like Pixtral / Mistral3
         if not ("attn_implementation" in kwargs):
             kwargs["attn_implementation"] = "sdpa"
         if not supports_sdpa:
             if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "0") == "0":
-                print(f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager.")
+                print(
+                    f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager."
+                )
             del kwargs["attn_implementation"]
-        pass
 
         bnb_config = None
         if full_finetuning and (load_in_4bit or load_in_8bit):
-            print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
-            load_in_4bit  = False
-            load_in_8bit  = False
+            print(
+                "Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA."
+            )
+            load_in_4bit = False
+            load_in_8bit = False
             load_in_16bit = False
-        pass
 
         if int(load_in_4bit) + int(load_in_8bit) + int(load_in_16bit) >= 2:
-            raise RuntimeError("Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!")
+            raise RuntimeError(
+                "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!"
+            )
         if load_in_4bit:
             bnb_config = BitsAndBytesConfig(
-                load_in_4bit              = True,
+                load_in_4bit = True,
                 bnb_4bit_use_double_quant = True,
-                bnb_4bit_quant_type       = "nf4",
-                bnb_4bit_compute_dtype    = bnb_compute_dtype,
-                llm_int8_skip_modules     = SKIP_QUANTIZATION_MODULES.copy(),
+                bnb_4bit_quant_type = "nf4",
+                bnb_4bit_compute_dtype = bnb_compute_dtype,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         elif load_in_8bit:
             bnb_config = BitsAndBytesConfig(
-                load_in_8bit              = True,
-                llm_int8_skip_modules     = SKIP_QUANTIZATION_MODULES.copy(),
+                load_in_8bit = True,
+                llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
         elif load_in_16bit:
             bnb_config = None
         elif not load_in_4bit and not load_in_8bit and not full_finetuning:
-            print("Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.")
-        pass
+            print(
+                "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA."
+            )
 
         if full_finetuning:
             os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "1"
             if dtype == torch.bfloat16:
-                print("Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.")
+                print(
+                    "Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%."
+                )
             else:
-                print("Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.")
+                print(
+                    "Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32."
+                )
         else:
             os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "0"
-        pass
 
         # Fix AttributeError: 'BitsAndBytesConfig' object has no attribute 'get_loading_attributes'
         if bnb_config is not None and not hasattr(bnb_config, "get_loading_attributes"):
@@ -524,7 +565,10 @@ def from_pretrained(
         # Cannot be None, since HF now checks for the config
         if load_in_4bit or load_in_8bit:
             # Ignore load_in_4bit / load_in_8bit for MXFP4 - best to get config file
-            if "gpt-oss-20b" in model_name.lower() or "gpt-oss-120b" in model_name.lower():
+            if (
+                "gpt-oss-20b" in model_name.lower()
+                or "gpt-oss-120b" in model_name.lower()
+            ):
                 pass
             else:
                 kwargs["quantization_config"] = bnb_config
@@ -536,28 +580,40 @@ def from_pretrained(
                     trust_remote_code = trust_remote_code,
                 )
             if hasattr(auto_config, "quantization_config"):
-                from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING
+                from transformers.quantizers.auto import (
+                    AUTO_QUANTIZATION_CONFIG_MAPPING,
+                )
+
                 quantization_config = auto_config.quantization_config
                 quant_method = quantization_config["quant_method"]
                 # Sometimes bitsandbytes_4bit + bitsandbytes_8bit is provided
-                if quant_method == "bitsandbytes" and "bitsandbytes" not in AUTO_QUANTIZATION_CONFIG_MAPPING:
+                if (
+                    quant_method == "bitsandbytes"
+                    and "bitsandbytes" not in AUTO_QUANTIZATION_CONFIG_MAPPING
+                ):
                     if "bitsandbytes_4bit" not in AUTO_QUANTIZATION_CONFIG_MAPPING:
-                        raise KeyError("Unsloth: AUTO_QUANTIZATION_CONFIG_MAPPING does not have `bitsandbytes_4bit`")
+                        raise KeyError(
+                            "Unsloth: AUTO_QUANTIZATION_CONFIG_MAPPING does not have `bitsandbytes_4bit`"
+                        )
                     quantizer = AUTO_QUANTIZATION_CONFIG_MAPPING["bitsandbytes_4bit"]
                 else:
                     quantizer = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
                 quantizer_kwargs = {}
                 # We cannot dequantize since gpt-oss-20b MXFP4 will now be gpt-oss-20b-BF16
-                if load_in_16bit and "dequantize" in inspect.signature(quantizer).parameters:
+                if (
+                    load_in_16bit
+                    and "dequantize" in inspect.signature(quantizer).parameters
+                ):
                     quantizer_kwargs["dequantize"] = True
-                quantization_config = quantizer.from_dict(quantization_config, **quantizer_kwargs)
+                quantization_config = quantizer.from_dict(
+                    quantization_config, **quantizer_kwargs
+                )
                 kwargs["quantization_config"] = quantization_config
-            pass
-        pass
 
         # Check if using forced float32 - we load it in bfloat16, then cast to float16!
         torch_dtype = dtype
-        if do_forced_float32: torch_dtype = torch.bfloat16
+        if do_forced_float32:
+            torch_dtype = torch.bfloat16
 
         kwargs = add_dtype_kwargs(torch_dtype, kwargs)
 
@@ -565,11 +621,11 @@ def from_pretrained(
         if not fast_inference:
             model = auto_model.from_pretrained(
                 model_name,
-                device_map              = device_map,
+                device_map = device_map,
                 # torch_dtype           = torch_dtype, # Transformers removed torch_dtype
                 # quantization_config   = bnb_config,
-                token                   = token,
-                trust_remote_code       = trust_remote_code,
+                token = token,
+                trust_remote_code = trust_remote_code,
                 # attn_implementation   = attn_implementation,
                 **kwargs,
             )
@@ -603,6 +659,7 @@ def from_pretrained(
                 convert_vllm_to_huggingface,
                 generate_batches,
             )
+
             model_config = AutoConfig.from_pretrained(
                 model_name,
                 token = token,
@@ -611,27 +668,28 @@ def from_pretrained(
             model_config.model_name = model_name
 
             if fast_inference:
-                fast_inference, model_name = fast_inference_setup(model_name, model_config)
+                fast_inference, model_name = fast_inference_setup(
+                    model_name, model_config
+                )
 
             allowed_args = inspect.getfullargspec(load_vllm).args
             load_vllm_kwargs = dict(
-                model_name             = model_name,
-                config                 = model_config,
+                model_name = model_name,
+                config = model_config,
                 gpu_memory_utilization = gpu_memory_utilization,
-                max_seq_length         = max_seq_length,
-                dtype                  = dtype,
-                float8_kv_cache        = float8_kv_cache,
-                enable_lora            = vllm_enable_lora,
-                max_lora_rank          = max_lora_rank,
-                disable_log_stats      = disable_log_stats,
-                use_bitsandbytes       = load_in_4bit,
-                unsloth_vllm_standby   = unsloth_vllm_standby,
-                is_vision_model        = is_vlm,
+                max_seq_length = max_seq_length,
+                dtype = dtype,
+                float8_kv_cache = float8_kv_cache,
+                enable_lora = vllm_enable_lora,
+                max_lora_rank = max_lora_rank,
+                disable_log_stats = disable_log_stats,
+                use_bitsandbytes = load_in_4bit,
+                unsloth_vllm_standby = unsloth_vllm_standby,
+                is_vision_model = is_vlm,
             )
             for allowed_arg in allowed_args:
                 if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
                     load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg]
-            pass
 
             # Load vLLM first
             llm = load_vllm(**load_vllm_kwargs)
@@ -645,13 +703,15 @@ def from_pretrained(
             model = convert_vllm_to_huggingface(
                 quant_state_dict,
                 model_config,
-                dtype, bnb_config,
+                dtype,
+                bnb_config,
                 is_vision_model = is_vlm,
             )
             model.vllm_engine = llm
             model.fast_generate = model.vllm_engine.generate
-            model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine)
-        pass
+            model.fast_generate_batches = functools.partial(
+                generate_batches, model.vllm_engine
+            )
 
         raise_handler.remove()
 
@@ -661,48 +721,49 @@ def from_pretrained(
         # Check float32 norm weights
         if os.environ.get("UNSLOTH_HIGH_PRECISION_LAYERNORM", "0") == "1":
             for jj, (name, module) in enumerate(model.named_modules()):
-                if (name.endswith(("norm", "norm1", "norm2", "norm3", "norm4")) \
-                    or "layernorm" in name or "layer_norm" in name) \
-                    and hasattr(module, "weight"):
+                if (
+                    name.endswith(("norm", "norm1", "norm2", "norm3", "norm4"))
+                    or "layernorm" in name
+                    or "layer_norm" in name
+                ) and hasattr(module, "weight"):
                     module._pre_set_compute_dtype = torch.float32
-        pass
         # Edit data-types
         if custom_datatype is not None:
             with torch.no_grad():
                 for jj, (name, module) in enumerate(model.named_modules()):
                     exec(custom_datatype)
-                pass
-            pass
-        pass
         # Clear deleted GPU items
         for _ in range(3):
             gc.collect()
-            if DEVICE_TYPE in ("cuda", "hip"):  torch.cuda.empty_cache()
-            elif DEVICE_TYPE == "xpu": torch.xpu.empty_cache()
-        pass
+            if DEVICE_TYPE in ("cuda", "hip"):
+                torch.cuda.empty_cache()
+            elif DEVICE_TYPE == "xpu":
+                torch.xpu.empty_cache()
 
         # Counteract saved tokenizers
         tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
-        if (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"):
-           tokenizer = auto_processor.from_pretrained(
+        if (whisper_language and whisper_task) or auto_model.__name__.endswith(
+            "ForConditionalGeneration"
+        ):
+            tokenizer = auto_processor.from_pretrained(
                 tokenizer_name,
                 padding_side = "left",
-                token        = token,
-                language     = whisper_language,
-                task         = whisper_task,
+                token = token,
+                language = whisper_language,
+                task = whisper_task,
             )
         else:
             try:
                 tokenizer = auto_processor.from_pretrained(
                     tokenizer_name,
                     padding_side = "left",
-                    token        = token,
+                    token = token,
                 )
             except:
                 tokenizer = get_auto_processor(
                     tokenizer_name,
                     padding_side = "left",
-                    token        = token,
+                    token = token,
                 )
         if hasattr(tokenizer, "tokenizer"):
             __tokenizer = tokenizer.tokenizer
@@ -710,15 +771,14 @@ def from_pretrained(
             __tokenizer.padding_side = "left"
             # Check bos, eos, pad tokens
             if hasattr(__tokenizer, "bos_token"):
-                tokenizer.bos_token    = __tokenizer.bos_token
+                tokenizer.bos_token = __tokenizer.bos_token
                 tokenizer.bos_token_id = __tokenizer.bos_token_id
             if hasattr(__tokenizer, "eos_token"):
-                tokenizer.eos_token    = __tokenizer.eos_token
+                tokenizer.eos_token = __tokenizer.eos_token
                 tokenizer.eos_token_id = __tokenizer.eos_token_id
             if hasattr(__tokenizer, "pad_token"):
-                tokenizer.pad_token    = __tokenizer.pad_token
+                tokenizer.pad_token = __tokenizer.pad_token
                 tokenizer.pad_token_id = __tokenizer.pad_token_id
-        pass
         # Fix other stuff like BnB compute data types
         model, tokenizer = patch_model_and_tokenizer(
             model,
@@ -733,22 +793,24 @@ def from_pretrained(
 
         # Log Unsloth version for future fastpaths for inference
         if hasattr(model, "config"):
-            model.config.update({"unsloth_version" : __version__})
-        pass
+            model.config.update({"unsloth_version": __version__})
         patch_saving_functions(model, vision = True)
         if tokenizer is None:
             del model
-            raise RuntimeError("Unsloth: The tokenizer is weirdly not loaded? Please check if there is one.")
+            raise RuntimeError(
+                "Unsloth: The tokenizer is weirdly not loaded? Please check if there is one."
+            )
         patch_saving_functions(tokenizer, vision = True)
 
         # Fix gradient accumulation
         from transformers.trainer import Trainer
+
         patch_gradient_accumulation_fix(Trainer)
 
         # Save tokenizer for inference purposes
-        tokenizer.padding_side = "left" # Force inference
+        tokenizer.padding_side = "left"  # Force inference
         if hasattr(tokenizer, "tokenizer"):
-            tokenizer.tokenizer.padding_side = "left" # Force inference
+            tokenizer.tokenizer.padding_side = "left"  # Force inference
         m = model
         while hasattr(m, "model"):
             m.max_seq_length = max_seq_length
@@ -756,7 +818,6 @@ def from_pretrained(
             # Also set is_loaded_in_8bit to disable incorrect DDP
             m.is_loaded_in_8bit = True if not full_finetuning else False
             m = m.model
-        pass
         m.max_seq_length = max_seq_length
         # Save to modules as well
         for module in model.modules():
@@ -766,18 +827,19 @@ def from_pretrained(
         m.is_loaded_in_8bit = True if not full_finetuning else False
 
         # Patch generate
-        if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0" and hasattr(model, 'generate'):
+        if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0" and hasattr(
+            model, "generate"
+        ):
             if model.generate.__name__ != "unsloth_base_fast_generate":
                 model._old_generate = model.generate
                 unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__
                 model.generate = types.MethodType(unsloth_base_fast_generate, model)
-        pass
         model._unsloth_trust_remote_code = trust_remote_code
         # Post patches
         model = FastBaseModel.post_patch_model(
             model,
             use_gradient_checkpointing = use_gradient_checkpointing,
-            trust_remote_code  = trust_remote_code,
+            trust_remote_code = trust_remote_code,
             model_type = model_type_arch,
             tokenizer = tokenizer,
         )
@@ -788,40 +850,39 @@ def from_pretrained(
                 torch.cuda.empty_cache()
             elif DEVICE_TYPE == "xpu":
                 torch.xpu.empty_cache()
-        pass
         return model, tokenizer
-    pass
 
     @staticmethod
     def get_peft_model(
         model,
-        r                          = 16,
-        target_modules             = None,
-        lora_alpha                 = 16,
-        lora_dropout               = 0.0,
-        bias                       = "none",
-        finetune_vision_layers     = True,
-        finetune_language_layers   = True,
+        r = 16,
+        target_modules = None,
+        lora_alpha = 16,
+        lora_dropout = 0.0,
+        bias = "none",
+        finetune_vision_layers = True,
+        finetune_language_layers = True,
         finetune_attention_modules = True,
-        finetune_mlp_modules       = True,
-        layers_to_transform        = None,
-        layers_pattern             = None,
+        finetune_mlp_modules = True,
+        layers_to_transform = None,
+        layers_pattern = None,
         use_gradient_checkpointing = "unsloth",
-        random_state               = 3407,
-        max_seq_length             = 2048, # not used anymore
-        use_rslora                 = False,
-        modules_to_save            = None,
-        init_lora_weights          = True,
-        loftq_config               = {},
-        task_type                  = TaskType.CAUSAL_LM,
-        temporary_location         = "_unsloth_temporary_saved_buffers",
-        qat_scheme                 = None,
-        **kwargs
+        random_state = 3407,
+        max_seq_length = 2048,  # not used anymore
+        use_rslora = False,
+        modules_to_save = None,
+        init_lora_weights = True,
+        loftq_config = {},
+        task_type = TaskType.CAUSAL_LM,
+        temporary_location = "_unsloth_temporary_saved_buffers",
+        qat_scheme = None,
+        **kwargs,
     ):
         if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
-            print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect")
+            print(
+                "Unsloth: Full finetuning is enabled, so .get_peft_model has no effect"
+            )
             return model
-        pass
         transformers_set_seed(random_state)
 
         if type(r) is not int:
@@ -830,28 +891,39 @@ def get_peft_model(
             raise TypeError(f"Unsloth: Rank of {str(r)} must be larger than 0.")
 
         if isinstance(model, PeftModelForCausalLM):
-            raise RuntimeError("Unsloth: You already added LoRA adapters to your model!")
+            raise RuntimeError(
+                "Unsloth: You already added LoRA adapters to your model!"
+            )
 
         if target_modules == "all-linear":
-            finetune_vision_layers     = True
-            finetune_language_layers   = True
+            finetune_vision_layers = True
+            finetune_language_layers = True
             finetune_attention_modules = True
-            finetune_mlp_modules       = True
-        pass
+            finetune_mlp_modules = True
         if target_modules is None or target_modules == "all-linear":
             target_modules = get_peft_regex(
                 model,
-                finetune_vision_layers     = finetune_vision_layers,
-                finetune_language_layers   = finetune_language_layers,
+                finetune_vision_layers = finetune_vision_layers,
+                finetune_language_layers = finetune_language_layers,
                 finetune_attention_modules = finetune_attention_modules,
-                finetune_mlp_modules       = finetune_mlp_modules,
+                finetune_mlp_modules = finetune_mlp_modules,
             )
         else:
-            assert(type(target_modules) in (list, tuple, str,))
-        pass
+            assert type(target_modules) in (
+                list,
+                tuple,
+                str,
+            )
 
         if hasattr(model, "vllm_engine"):
-            if hasattr(model.vllm_engine, "llm_engine") and hasattr(model.vllm_engine.llm_engine, "vllm_config") and getattr(model.vllm_engine.llm_engine.vllm_config, "lora_config", None) is None:
+            if (
+                hasattr(model.vllm_engine, "llm_engine")
+                and hasattr(model.vllm_engine.llm_engine, "vllm_config")
+                and getattr(
+                    model.vllm_engine.llm_engine.vllm_config, "lora_config", None
+                )
+                is None
+            ):
                 # If vLLM is being used but lora is not enabled, throw an error
                 # Ref https://github.com/vllm-project/vllm/blob/51ba839555a5d122eadd91e9c16463ac288f5fa1/vllm/v1/engine/processor.py#L148-L151
                 raise RuntimeError("Unsloth: LoRA is not enabled for this model!")
@@ -859,13 +931,17 @@ def get_peft_model(
                 # vLLM does not support LoRA on vision layers
                 # https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py#L471-L477
                 # TODO: Update this once vLLM V1 supports LoRA on vision layers (possibly not happening)
-                raise RuntimeError("Unsloth: Finetuning vision layers is not supported for fast_inference. Only text layers are supported!")
+                raise RuntimeError(
+                    "Unsloth: Finetuning vision layers is not supported for fast_inference. Only text layers are supported!"
+                )
             if model.config.model_type in VLLM_NON_LORA_VLM:
                 # mllama is still only in vllm v0 https://arc.net/l/quote/llwkfgmu
                 # https://docs.vllm.ai/en/stable/models/supported_models.html#text-generation_1
                 # vLLM V0 does not support LoRA on multi modal models.
                 # TODO: Update this once vLLM V1 supports Llama 3.2 aka mllama
-                raise RuntimeError("Unsloth: LoRA finetuning for Llama 3.2 aka mllama models is not supported with fast_inference!")
+                raise RuntimeError(
+                    "Unsloth: LoRA finetuning for Llama 3.2 aka mllama models is not supported with fast_inference!"
+                )
 
         # Clear deleted GPU items
         for _ in range(3):
@@ -874,17 +950,21 @@ def get_peft_model(
                 torch.cuda.empty_cache()
             elif DEVICE_TYPE == "xpu":
                 torch.xpu.empty_cache()
-        pass
         max_seq_length = model.max_seq_length
         # If we pass loftq_config = None we will get an error
-        loftq_config = validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, model)
+        loftq_config = validate_loftq_config(
+            loftq_config, lora_dropout, bias, init_lora_weights, model
+        )
 
         # Get only allowed parameters for LoraConfig
-        local_variables = { **locals(), **kwargs, }
+        local_variables = {
+            **locals(),
+            **kwargs,
+        }
         del local_variables["kwargs"]
         allowed_parameters = inspect.signature(LoraConfig).parameters.keys()
         lora_config = LoraConfig(
-            **{ k : v for k, v in local_variables.items() if k in allowed_parameters },
+            **{k: v for k, v in local_variables.items() if k in allowed_parameters},
         )
         model = prepare_model_for_kbit_training(
             model,
@@ -895,7 +975,6 @@ def get_peft_model(
         if qat_scheme is not None:
             print("Unsloth: Applying QAT to mitigate quantization degradation")
             model = _prepare_model_for_qat(model, qat_scheme)
-        pass
         # Fix LoraConfig.auto_mapping is None
         fix_lora_auto_mapping(model)
         # Enable gradients on modules which are trainable
@@ -917,21 +996,18 @@ def get_peft_model(
                 torch.cuda.empty_cache()
             elif DEVICE_TYPE == "xpu":
                 torch.xpu.empty_cache()
-        pass
         patch_saving_functions(model, vision = True)
         patch_peft_fast_inference(model)
 
         # Add for_inference and for_training
-        model.for_training  = functools.partial(FastBaseModel.for_training,  model)
+        model.for_training = functools.partial(FastBaseModel.for_training, model)
         model.for_inference = functools.partial(FastBaseModel.for_inference, model)
         m = model
         while hasattr(m, "model"):
-            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_training = functools.partial(FastBaseModel.for_training, m)
             m.for_inference = functools.partial(FastBaseModel.for_inference, m)
             m = m.model
         return model
-    pass
-
 
     @staticmethod
     def post_patch_model(
@@ -944,26 +1020,32 @@ def post_patch_model(
         full_finetuning = os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1"
 
         float32_mixed_precision = True
-        if _get_dtype(dtype_from_config(model.config)) == torch.bfloat16 and full_finetuning:
+        if (
+            _get_dtype(dtype_from_config(model.config)) == torch.bfloat16
+            and full_finetuning
+        ):
             # Use bfloat16 precision for full finetuning
             float32_mixed_precision = False
 
         model = prepare_model_for_training(
             model,
             use_gradient_checkpointing = use_gradient_checkpointing,
-            use_reentrant              = True,
-            full_finetuning            = full_finetuning,
-            train_layernorms           = full_finetuning,
-            train_embedding            = full_finetuning,
-            train_lm_head              = full_finetuning,
-            float32_mixed_precision    = float32_mixed_precision,
-            patch_modules_to_save      = True,
+            use_reentrant = True,
+            full_finetuning = full_finetuning,
+            train_layernorms = full_finetuning,
+            train_embedding = full_finetuning,
+            train_lm_head = full_finetuning,
+            float32_mixed_precision = float32_mixed_precision,
+            patch_modules_to_save = True,
         )
 
         from transformers.trainer import Trainer
-        if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop" and trust_remote_code == False:
-            raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop')
-        pass
+
+        if (
+            Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop"
+            and trust_remote_code == False
+        ):
+            raise RuntimeError("Unsloth: Unsuccessfully patched inner_training_loop")
         patch_saving_functions(model, vision = True)
 
         # Patch tokenizer to pad to the left
@@ -972,15 +1054,12 @@ def post_patch_model(
             if hasattr(m, "_saved_temp_tokenizer"):
                 if hasattr(m._saved_temp_tokenizer, "tokenizer"):
                     m._saved_temp_tokenizer.tokenizer.padding_side = "left"
-            pass
             # Also set is_loaded_in_8bit to disable incorrect DDP
             m.is_loaded_in_8bit = True if not full_finetuning else False
             m = m.model
-        pass
         if hasattr(m, "_saved_temp_tokenizer"):
             if hasattr(m._saved_temp_tokenizer, "tokenizer"):
                 m._saved_temp_tokenizer.tokenizer.padding_side = "left"
-        pass
         # Also set is_loaded_in_8bit to disable incorrect DDP
         m.is_loaded_in_8bit = True if not full_finetuning else False
 
@@ -991,64 +1070,74 @@ def post_patch_model(
                 torch.cuda.empty_cache()
             elif DEVICE_TYPE == "xpu":
                 torch.xpu.empty_cache()
-        pass
         # Add for_inference and for_training
-        model.for_training  = functools.partial(FastBaseModel.for_training,  model)
+        model.for_training = functools.partial(FastBaseModel.for_training, model)
         model.for_inference = functools.partial(FastBaseModel.for_inference, model)
         m = model
         while hasattr(m, "model"):
-            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_training = functools.partial(FastBaseModel.for_training, m)
             m.for_inference = functools.partial(FastBaseModel.for_inference, m)
             m = m.model
         # Set weight[padding_idx] = 0
         # Only do this if tokenizer is defined since eos_token == pad_token sometimes!
         pad_token_id = getattr(tokenizer, "pad_token_id", None)
-        if tokenizer is not None and getattr(tokenizer, "eos_token_id", None) != pad_token_id:
+        if (
+            tokenizer is not None
+            and getattr(tokenizer, "eos_token_id", None) != pad_token_id
+        ):
             with torch.no_grad():
                 for name, module in model.named_modules():
                     if type(module) is torch.nn.Embedding:
-                        if getattr(module, "weight", None) is not None and getattr(module, "padding_idx", None) is not None:
-                            if module.padding_idx == pad_token_id and module.padding_idx < module.weight.shape[0]:
+                        if (
+                            getattr(module, "weight", None) is not None
+                            and getattr(module, "padding_idx", None) is not None
+                        ):
+                            if (
+                                module.padding_idx == pad_token_id
+                                and module.padding_idx < module.weight.shape[0]
+                            ):
                                 module.weight[module.padding_idx] = 0
         return model
-    pass
-
 
     @staticmethod
     def for_inference(model):
         if not hasattr(model, "parameters"):
-            raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!")
+            raise TypeError(
+                "Unsloth: I think you're passing a tokenizer, not the model to for_inference!"
+            )
 
         def _for_inference(m):
-            if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False
-            if hasattr(m, "training"): m.training = False
+            if hasattr(m, "gradient_checkpointing"):
+                m.gradient_checkpointing = False
+            if hasattr(m, "training"):
+                m.training = False
             # Pad tokenizer to the left
-            if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "left"
+            if hasattr(m, "_saved_temp_tokenizer"):
+                m._saved_temp_tokenizer.padding_side = "left"
             # Set a flag for generation!
             m._flag_for_generation = True
-        pass
+
         m = model
         while hasattr(m, "model"):
             _for_inference(m)
             m = m.model
         _for_inference(m)
-        model.eval() # to turn off training on modules deeper in
+        model.eval()  # to turn off training on modules deeper in
 
         # Since transformers 4.53, must turn off explicitly
         for module in model.modules():
             if hasattr(module, "gradient_checkpointing"):
                 module.gradient_checkpointing = False
-        pass
 
         # Also disable training for embeddings for NEFTune
         if hasattr(model, "get_input_embeddings"):
             embeddings = model.get_input_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = False
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = False
         if hasattr(model, "get_output_embeddings"):
             embeddings = model.get_output_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = False
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = False
         # Must disable returning hidden states in the case for GRPO
         os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
         # Must enable returning logits
@@ -1057,25 +1146,27 @@ def _for_inference(m):
         if torch_compiler_set_stance is not None:
             torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
         return model
-    pass
-
 
     @staticmethod
     def for_training(model, use_gradient_checkpointing = True):
         if not hasattr(model, "parameters"):
-            raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_training!")
+            raise TypeError(
+                "Unsloth: I think you're passing a tokenizer, not the model to for_training!"
+            )
 
         # Delete all fast inference loras
         for param in model.parameters():
             if hasattr(param, "_fast_lora"):
                 del param._fast_lora
-        pass
 
         def _for_training(m):
-            if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = use_gradient_checkpointing
-            if hasattr(m, "training"): m.training = True
+            if hasattr(m, "gradient_checkpointing"):
+                m.gradient_checkpointing = use_gradient_checkpointing
+            if hasattr(m, "training"):
+                m.training = True
             # Pad tokenizer to the left
-            if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "right"
+            if hasattr(m, "_saved_temp_tokenizer"):
+                m._saved_temp_tokenizer.padding_side = "right"
             # Set a flag for generation!
             if hasattr(m, "_flag_for_generation"):
                 try:
@@ -1083,34 +1174,31 @@ def _for_training(m):
                     del m._flag_for_generation
                 except:
                     pass
-        pass
+
         m = model
         while hasattr(m, "model"):
             _for_training(m)
             m = m.model
         _for_training(m)
-        model.train() # to turn on training on modules deeper in
+        model.train()  # to turn on training on modules deeper in
 
         # Since transformers 4.53, must turn on explicitly
         for module in model.modules():
             if hasattr(module, "gradient_checkpointing"):
                 module.gradient_checkpointing = True
-        pass
 
         # Also re-enable training for embeddings for NEFTune
         if hasattr(model, "get_input_embeddings"):
             embeddings = model.get_input_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = True
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = True
         if hasattr(model, "get_output_embeddings"):
             embeddings = model.get_output_embeddings()
-            if hasattr(embeddings, "training"): embeddings.training = True
-        pass
+            if hasattr(embeddings, "training"):
+                embeddings.training = True
         # Can re-enable not returning logits
         os.environ["UNSLOTH_RETURN_LOGITS"] = "0"
         # Turn off skip guards and set stance to default
         if torch_compiler_set_stance is not None:
             torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
         return model
-    pass
-pass
diff --git a/unsloth/ollama_template_mappers.py b/unsloth/ollama_template_mappers.py
index 1ac95f3a3..3ad0e334d 100644
--- a/unsloth/ollama_template_mappers.py
+++ b/unsloth/ollama_template_mappers.py
@@ -22,8 +22,7 @@
 
 # =========================================== Unsloth
 
-unsloth_ollama = \
-'''
+unsloth_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }}
 {{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
@@ -36,12 +35,10 @@
 '''
 
 OLLAMA_TEMPLATES["unsloth"] = unsloth_ollama
-pass
 
 # =========================================== Zephyr
 
-zephyr_ollama = \
-'''
+zephyr_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|system|>
 {{ .System }}{__EOS_TOKEN__}
@@ -56,11 +53,9 @@
 '''
 
 OLLAMA_TEMPLATES["zephyr"] = zephyr_ollama
-pass
 
 # =========================================== ChatML
-chatml_ollama = \
-'''
+chatml_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
@@ -76,14 +71,12 @@
 '''
 
 OLLAMA_TEMPLATES["chatml"] = chatml_ollama
-pass
 
 # =========================================== Mistral-1
 # Ollama from https://www.ollama.com/library/mistral
 # Mistral v0.1 https://ollama.com/library/mistral:v0.1/blobs/22e1b2e8dc2f
 # Mistral v0.2 https://ollama.com/library/mistral:v0.2/blobs/e6836092461f
-mistral_ollama = \
-'''
+mistral_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
 PARAMETER stop "[INST]"
@@ -92,8 +85,7 @@
 
 # mistral:v0.3 https://ollama.com/library/mistral:v0.3/blobs/1ff5b64b61b9
 # mistral-large https://ollama.com/library/mistral-large:latest/blobs/96adabcf2c08
-mistral_v03_ollama = \
-'''
+mistral_v03_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if .Messages }}
 {{- range $index, $_ := .Messages }}
@@ -123,8 +115,7 @@
 '''
 
 # Mistral-small https://ollama.com/library/mistral-small:latest/blobs/6db27cd4e277
-mistral_small_ollama = \
-'''
+mistral_small_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $index, $_ := .Messages }}
 {{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
@@ -147,8 +138,7 @@
 '''
 
 # mistral-small-3.1 https://ollama.com/library/mistral-small3.1:latest/blobs/6db27cd4e277
-mistral_small_31_ollama = \
-'''
+mistral_small_31_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $index, $_ := .Messages }}
 {{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
@@ -188,8 +178,7 @@
 '''
 
 # mistral-small-3.2 https://ollama.com/library/mistral-small3.2:latest/blobs/706c4d1164f7
-mistral_small_32_ollama = \
-'''
+mistral_small_32_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $index, $_ := .Messages }}
 {{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
@@ -240,8 +229,7 @@
 
 
 # https://ollama.com/library/mixtral:latest/blobs/53d74de0d84c
-mixtral_ollama = \
-'''
+mixtral_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}"""
 PARAMETER stop "[INST]"
@@ -249,8 +237,7 @@
 '''
 
 # https://registry.ollama.ai/library/mistral-nemo:latest/blobs/438402ddac75
-mistral_nemo_ollama = \
-'''
+mistral_nemo_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{- range $i, $_ := .Messages }}
@@ -273,8 +260,7 @@
 '''
 
 # https://ollama.com/library/codestral:latest/blobs/51707752a87c
-codestral_ollama = \
-'''
+codestral_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{- if .Suffix }}[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
@@ -301,8 +287,7 @@
 '''
 
 # https://ollama.com/library/devstral:latest/blobs/ea9ec42474e0
-devstral_ollama = \
-'''
+devstral_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- $lastUserIndex := -1 }}
 {{- range $index, $_ := .Messages }}
@@ -400,8 +385,7 @@
 '''
 
 # https://ollama.com/library/magistral:latest/blobs/35f7a1efc383
-magistral_ollama = \
-'''
+magistral_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{- range $i, $_ := .Messages }}
@@ -443,23 +427,21 @@
 Problem:"""
 '''
 
-OLLAMA_TEMPLATES["mistral"]          = mistral_ollama
-OLLAMA_TEMPLATES["mistral-v03"]      = mistral_v03_ollama
-OLLAMA_TEMPLATES["mistral-small"]    = mistral_small_ollama
+OLLAMA_TEMPLATES["mistral"] = mistral_ollama
+OLLAMA_TEMPLATES["mistral-v03"] = mistral_v03_ollama
+OLLAMA_TEMPLATES["mistral-small"] = mistral_small_ollama
 OLLAMA_TEMPLATES["mistral-small-31"] = mistral_small_31_ollama
 OLLAMA_TEMPLATES["mistral-small-32"] = mistral_small_32_ollama
-OLLAMA_TEMPLATES["mixtral"]          = mixtral_ollama
-OLLAMA_TEMPLATES["mistral-nemo"]     = mistral_nemo_ollama
-OLLAMA_TEMPLATES["devstral"]         = devstral_ollama
-OLLAMA_TEMPLATES["magistral"]        = magistral_ollama
-OLLAMA_TEMPLATES["codestral"]        = codestral_ollama
+OLLAMA_TEMPLATES["mixtral"] = mixtral_ollama
+OLLAMA_TEMPLATES["mistral-nemo"] = mistral_nemo_ollama
+OLLAMA_TEMPLATES["devstral"] = devstral_ollama
+OLLAMA_TEMPLATES["magistral"] = magistral_ollama
+OLLAMA_TEMPLATES["codestral"] = codestral_ollama
 
-pass
 
 # =========================================== Llama-2
 # Ollama from https://www.ollama.com/library/llama3
-llama_ollama = \
-'''
+llama_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """[INST] <>{{ .System }}<>
 
@@ -469,13 +451,11 @@
 PARAMETER min_p 0.1
 '''
 
-OLLAMA_TEMPLATES["llama"] =llama_ollama
-pass
+OLLAMA_TEMPLATES["llama"] = llama_ollama
 
 # ===========================================  Vicuna
 # Ollama from https://www.ollama.com/library/vicuna
-vicuna_ollama = \
-'''
+vicuna_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
 PARAMETER stop "{__EOS_TOKEN__}"
@@ -484,11 +464,9 @@
 '''
 
 OLLAMA_TEMPLATES["vicuna"] = vicuna_ollama
-pass
 
 # =========================================== Vicuna Old
-vicuna_old_ollama = \
-'''
+vicuna_old_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }}
 {{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
@@ -502,11 +480,9 @@
 
 OLLAMA_TEMPLATES["vicuna_old"] = vicuna_old_ollama
 OLLAMA_TEMPLATES["vicuna old"] = OLLAMA_TEMPLATES["vicuna_old"]
-pass
 
 # =========================================== Alpaca multi turn
-alpaca_ollama = \
-'''
+alpaca_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}{{ .System }}
 
@@ -524,12 +500,10 @@
 '''
 
 OLLAMA_TEMPLATES["alpaca"] = alpaca_ollama
-pass
 
 # =========================================== Gemma
 # Ollama from https://www.ollama.com/library/gemma
-gemma_ollama = \
-'''
+gemma_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """user
 {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}
@@ -545,11 +519,9 @@
 '''
 
 OLLAMA_TEMPLATES["gemma"] = gemma_ollama
-pass
 
 # =========================================== Gemma with ChatML instead
-gemma_chatml_ollama = \
-'''
+gemma_chatml_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
@@ -567,7 +539,6 @@
 '''
 
 OLLAMA_TEMPLATES["gemma_chatml"] = gemma_chatml_ollama
-pass
 
 # =========================================== Gemma 2
 # Same as Gemma 1, but with sliding window attention!
@@ -578,12 +549,10 @@
 # =========================================== Gemma 2 with ChatML instead
 gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
 OLLAMA_TEMPLATES["gemma2_chatml"] = gemma2_chatml_ollama
-pass
 
 # =========================================== Llama-3
 # Ollama from https://www.ollama.com/library/llama3
-llama3_ollama = \
-'''
+llama3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
 
@@ -602,13 +571,11 @@
 
 OLLAMA_TEMPLATES["llama-3"] = llama3_ollama
 OLLAMA_TEMPLATES["llama3"] = llama3_ollama
-pass
 
 
 # =========================================== Phi-3
 # Ollama from https://www.ollama.com/library/phi3
-phi3_ollama = \
-'''
+phi3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|system|>
 {{ .System }}<|end|>
@@ -624,10 +591,9 @@
 PARAMETER min_p 0.1
 '''
 
-OLLAMA_TEMPLATES["phi-3"]   = phi3_ollama
-OLLAMA_TEMPLATES["phi-35"]  = OLLAMA_TEMPLATES["phi-3"]
+OLLAMA_TEMPLATES["phi-3"] = phi3_ollama
+OLLAMA_TEMPLATES["phi-35"] = OLLAMA_TEMPLATES["phi-3"]
 OLLAMA_TEMPLATES["phi-3.5"] = OLLAMA_TEMPLATES["phi-3"]
-pass
 
 # =========================================== Llama-3.1
 """
@@ -647,8 +613,7 @@
 """
 
 # Ollama from https://ollama.com/library/llama3.1 (needs updating!)
-llama31_ollama = \
-'''
+llama31_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .Messages }}
 {{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
@@ -708,8 +673,7 @@
 '''
 
 # https://ollama.com/ajindal/llama3.1-storm:8b/blobs/1970553b62f4
-llama_31_storm_ollama = \
-'''
+llama_31_storm_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{ if .Messages }}
@@ -764,8 +728,7 @@
 '''
 
 # https://ollama.com/library/nemotron:latest/blobs/4863fe3335f3
-llama_31_nemotron_ollama = \
-'''
+llama_31_nemotron_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """<|start_header_id|>system<|end_header_id|>
 
@@ -808,8 +771,7 @@
 '''
 
 # https://ollama.com/library/llama3.2-vision:latest/blobs/715415638c895a1f8e8c6
-llama_32_vision_ollama = \
-'''
+llama_32_vision_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $index, $_ := .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
 
@@ -823,20 +785,18 @@
 PARAMETER top_p 0.9
 '''
 
-OLLAMA_TEMPLATES["llama-3.1"]         = llama31_ollama
-OLLAMA_TEMPLATES["llama-31"]          = llama31_ollama
+OLLAMA_TEMPLATES["llama-3.1"] = llama31_ollama
+OLLAMA_TEMPLATES["llama-31"] = llama31_ollama
 OLLAMA_TEMPLATES["llama-31-nemotron"] = llama_31_nemotron_ollama
-OLLAMA_TEMPLATES["llama-31-storm"]    = llama_31_storm_ollama
-OLLAMA_TEMPLATES["llama-32-vision"]   = llama_32_vision_ollama
+OLLAMA_TEMPLATES["llama-31-storm"] = llama_31_storm_ollama
+OLLAMA_TEMPLATES["llama-32-vision"] = llama_32_vision_ollama
 
 for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"):
     OLLAMA_TEMPLATES[version] = OLLAMA_TEMPLATES["llama-3.1"]
-pass
 
 # =========================================== tinyllama
 # tinyllama-chat https://ollama.com/library/tinyllama:latest/blobs/af0ddbdaaa26
-tinyllama_ollama = \
-'''
+tinyllama_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """<|system|>
 {{ .System }}
@@ -852,13 +812,11 @@
 
 OLLAMA_TEMPLATES["tinyllama"] = tinyllama_ollama
 
-pass
 
 # =========================================== Qwen 2/2.5
 # Qwen2 https://ollama.com/library/qwen2:latest/blobs/77c91b422cc9
 # Qwen2.5 from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
-qwen25_ollama = \
-'''
+qwen25_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if .Messages }}
 {{- if or .System .Tools }}<|im_start|>system
@@ -918,8 +876,7 @@
 '''
 
 # https://ollama.com/library/qwen2.5-coder:latest/blobs/1e65450c3067
-qwen_25_coder_ollama = \
-'''
+qwen_25_coder_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
 {{- else if .Messages }}
@@ -976,8 +933,7 @@
 '''
 
 # https://ollama.com/library/qwen2.5vl:latest/blobs/a242d8dfdc8f
-qwen_25_vl_ollama = \
-'''
+qwen_25_vl_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if .System -}}
 <|im_start|>system
@@ -1003,8 +959,7 @@
 '''
 
 # https://ollama.com/library/openthinker:latest/blobs/32695b892af8
-openthinker_ollama = \
-'''
+openthinker_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $i, $_ := .Messages }}
 {{- $last := eq (len (slice $.Messages $i)) 1 -}}
@@ -1018,22 +973,21 @@
 '''
 
 
-OLLAMA_TEMPLATES["qwen-25"]       = qwen25_ollama
+OLLAMA_TEMPLATES["qwen-25"] = qwen25_ollama
 OLLAMA_TEMPLATES["qwen-25-coder"] = qwen_25_coder_ollama
-OLLAMA_TEMPLATES["qwen-25-vl"]    = qwen_25_vl_ollama
-OLLAMA_TEMPLATES["openthinker"]   = openthinker_ollama
-OLLAMA_TEMPLATES["qwen-2"]        = qwen25_ollama
-pass
+OLLAMA_TEMPLATES["qwen-25-vl"] = qwen_25_vl_ollama
+OLLAMA_TEMPLATES["openthinker"] = openthinker_ollama
+OLLAMA_TEMPLATES["qwen-2"] = qwen25_ollama
 
 # =========================================== Phi-4
-_phi4_ollama_template = \
-    "{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"\
-    "{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}"\
+_phi4_ollama_template = (
+    "{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"
+    "{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}"
     "<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>"
+)
 
 # Ollama from https://www.ollama.com/library/phi4 is different
-phi_4_ollama = \
-f'''
+phi_4_ollama = f'''
 FROM {{__FILE_LOCATION__}}
 TEMPLATE """{_phi4_ollama_template}"""
 PARAMETER stop "<|im_end|>"
@@ -1044,8 +998,7 @@
 '''
 
 # https://ollama.com/library/phi4-reasoning:latest/blobs/32695b892af8
-phi_4_reasoning_ollama = \
-'''
+phi_4_reasoning_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{- range $i, $_ := .Messages }}
@@ -1063,8 +1016,7 @@
 '''
 
 # https://ollama.com/library/phi4-mini:latest/blobs/813f53fdc6e5
-phi_4_mini_ollama = \
-'''
+phi_4_mini_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if or .System .Tools }}<|system|>{{ if .System }}{{ .System }}{{ end }}
 {{- if .Tools }}{{ if not .System }}You are a helpful assistant with some tools.{{ end }}<|tool|>{{ .Tools }}<|/tool|><|end|>
@@ -1083,8 +1035,7 @@
 '''
 
 # https://ollama.com/library/phi4-mini-reasoning:latest/blobs/c895a1f8e8c6
-phi_4_mini_reasoning_ollama = \
-'''
+phi_4_mini_reasoning_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{- if .System }}<|system|>{{ .System }}
@@ -1099,17 +1050,15 @@
 {{- end }}"""
 SYSTEM """Your name is Phi, an AI math expert developed by Microsoft."""
 '''
-OLLAMA_TEMPLATES["phi-4"]                = phi_4_ollama
-OLLAMA_TEMPLATES["phi-4-reasoning"]      = phi_4_reasoning_ollama
-OLLAMA_TEMPLATES["phi-4-mini"]           = phi_4_mini_ollama
+OLLAMA_TEMPLATES["phi-4"] = phi_4_ollama
+OLLAMA_TEMPLATES["phi-4-reasoning"] = phi_4_reasoning_ollama
+OLLAMA_TEMPLATES["phi-4-mini"] = phi_4_mini_ollama
 OLLAMA_TEMPLATES["phi-4-mini-reasoning"] = phi_4_mini_reasoning_ollama
-pass
 
 
 # =========================================== Gemma-3
 # Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802
-gemma3_ollama = \
-'''
+gemma3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $i, $_ := .Messages }}
 {{- $last := eq (len (slice $.Messages $i)) 1 }}
@@ -1132,8 +1081,7 @@
 '''
 
 # https://ollama.com/library/gemma3:270m/blobs/4b19ac7dd2fb
-gemma3_270m_ollama = \
-'''
+gemma3_270m_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- $systemPromptAdded := false }}
 {{- range $i, $_ := .Messages }}
@@ -1157,16 +1105,14 @@
 PARAMETER top_p 0.95
 '''
 
-OLLAMA_TEMPLATES["gemma-3"]     = gemma3_ollama
-OLLAMA_TEMPLATES["gemma3"]      = gemma3_ollama
+OLLAMA_TEMPLATES["gemma-3"] = gemma3_ollama
+OLLAMA_TEMPLATES["gemma3"] = gemma3_ollama
 OLLAMA_TEMPLATES["gemma3-270m"] = gemma3_270m_ollama
 
-pass
 
 # =========================================== Qwen-3
 # Ollama template for Qwen-3 (see https://ollama.com/library/qwen3/blobs/eb4402837c78)
-qwen3_ollama = \
-'''
+qwen3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- if .Messages }}
 {{- if or .System .Tools }}<|im_start|>system
@@ -1231,12 +1177,10 @@
 OLLAMA_TEMPLATES["qwen-3"] = qwen3_ollama
 OLLAMA_TEMPLATES["qwen3"] = qwen3_ollama
 
-pass
 
 # =========================================== Gemma-3n
 # Ollama from https://ollama.com/library/gemma3n/blobs/e0a42594d802
-gemma3n_ollama = \
-'''
+gemma3n_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- range $i, $_ := .Messages }}
 {{- $last := eq (len (slice $.Messages $i)) 1 }}
@@ -1253,13 +1197,11 @@
 
 OLLAMA_TEMPLATES["gemma-3n"] = gemma3n_ollama
 OLLAMA_TEMPLATES["gemma3n"] = gemma3n_ollama
-pass
 
 # =========================================== GPT-OSS
 
 # Ollama from https://ollama.com/library/gpt-oss:latest/blobs/fa6710a93d78
-gptoss_ollama = \
-'''
+gptoss_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
 Knowledge cutoff: 2024-06
@@ -1440,15 +1382,13 @@
 '''
 
 OLLAMA_TEMPLATES["gpt-oss"] = gptoss_ollama
-OLLAMA_TEMPLATES["gptoss"]  = gptoss_ollama
+OLLAMA_TEMPLATES["gptoss"] = gptoss_ollama
 
-pass
 
 # =========================================== Qwen3
 
 # Ollama from https://ollama.com/library/qwen3/blobs/53e4ea15e8f5
-qwen3_ollama = \
-'''
+qwen3_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """
 {{- $lastUserIdx := -1 -}}
@@ -1507,15 +1447,12 @@
 OLLAMA_TEMPLATES["qwen3-instruct"] = qwen3_ollama
 OLLAMA_TEMPLATES["qwen3-thinking"] = qwen3_ollama
 
-pass
-
 
 # =========================================== Starling-LM
 
 
 # Ollama from https://ollama.com/library/starling-lm:7b/blobs/4b21bfc435b4
-starling_ollama = \
-'''
+starling_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>
 {{ end }}{{ if .Prompt }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>
@@ -1528,16 +1465,14 @@
 PARAMETER min_p 0.1
 '''
 
-OLLAMA_TEMPLATES["starling"] =  starling_ollama
+OLLAMA_TEMPLATES["starling"] = starling_ollama
 
-pass
 
 # =========================================== Yi-chat
 
 
 # Ollama from https://ollama.com/library/yi:34b-chat/blobs/62fbfd9ed093
-yi_chat_ollama = \
-'''
+yi_chat_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
@@ -1552,8 +1487,7 @@
 # =========================================== Granite
 
 # Ollama from https://ollama.com/library/granite3.2:latest/blobs/3e7ca51acd6e
-granite_32_ollama = \
-'''
+granite_32_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- /*
 
@@ -1709,8 +1643,7 @@
 '''
 
 # granite-3.2-vision https://ollama.com/library/granite3.2-vision:latest/blobs/579046ba1157
-granite_32_vision_ollama = \
-'''
+granite_32_vision_ollama = '''
 FROM {__FILE_LOCATION__}
 TEMPLATE """{{- /* Tools */ -}}
 {{- if .Tools -}}
@@ -1769,11 +1702,9 @@
 SYSTEM """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
 '''
 
-OLLAMA_TEMPLATES["granite-32"]        = granite_32_ollama
+OLLAMA_TEMPLATES["granite-32"] = granite_32_ollama
 OLLAMA_TEMPLATES["granite-32-vision"] = granite_32_vision_ollama
 
-pass
-
 
 OLLAMA_TEMPLATE_TO_MODEL_MAPPER = {
     "phi-3.5": (
@@ -1827,7 +1758,7 @@
         "unsloth/mistral-7b-instruct-v0.2",
         "mistralai/Mistral-7B-Instruct-v0.2",
     ),
-    "mistral-v03":(
+    "mistral-v03": (
         "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
         "unsloth/mistral-7b-instruct-v0.3",
         "mistralai/Mistral-7B-Instruct-v0.3",
@@ -1855,7 +1786,7 @@
         "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
         "unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit",
     ),
-    "mixtral":(
+    "mixtral": (
         "unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit",
         "unsloth/Mixtral-8x7B-Instruct-v0.1",
         "mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -1951,7 +1882,7 @@
         "unsloth/Llama-3.1-Storm-8B",
         "akjindal53244/Llama-3.1-Storm-8B",
     ),
-    "llama-31-nemotron":(
+    "llama-31-nemotron": (
         "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit",
         "unsloth/Llama-3.1-Nemotron-70B-Instruct",
         "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
@@ -1965,9 +1896,8 @@
         "unsloth/Llama-3.2-3B-Instruct",
         "meta-llama/Llama-3.2-3B-Instruct",
         "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
-
     ),
-    "llama-32-vision":(
+    "llama-32-vision": (
         "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
         "unsloth/Llama-3.2-11B-Vision-Instruct",
         "meta-llama/Llama-3.2-11B-Vision-Instruct",
@@ -2040,7 +1970,7 @@
         "google/gemma-3n-E2B-it",
         "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
     ),
-    "gemma3-270m":(
+    "gemma3-270m": (
         "unsloth/gemma-3-270m-it-unsloth-bnb-4bit",
         "unsloth/gemma-3-270m-it",
         "google/gemma-3-270m-it",
@@ -2082,9 +2012,8 @@
         "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit",
         "unsloth/Qwen2.5-Math-72B-Instruct",
         "Qwen/Qwen2.5-Math-72B-Instruct",
-
     ),
-    "qwen-25-coder":(
+    "qwen-25-coder": (
         "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit",
         "unsloth/Qwen2.5-Coder-0.5B-Instruct",
         "Qwen/Qwen2.5-Coder-0.5B-Instruct",
@@ -2104,7 +2033,7 @@
         "unsloth/Qwen2.5-Coder-32B-Instruct",
         "Qwen/Qwen2.5-Coder-32B-Instruct",
     ),
-    "qwen-25-vl":(
+    "qwen-25-vl": (
         "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit",
         "unsloth/Qwen2.5-VL-3B-Instruct",
         "Qwen/Qwen2.5-VL-3B-Instruct",
@@ -2256,11 +2185,8 @@
 for key, values in OLLAMA_TEMPLATE_TO_MODEL_MAPPER.items():
     for value in values:
         MODEL_TO_OLLAMA_TEMPLATE_MAPPER[value] = key
-    pass
 
     # Get lowercased
     lowered_key = key.lower()
     for value in values:
         MODEL_TO_OLLAMA_TEMPLATE_MAPPER[value.lower()] = lowered_key
-    pass
-pass
diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py
index 587474369..52e6f5243 100644
--- a/unsloth/registry/__init__.py
+++ b/unsloth/registry/__init__.py
@@ -8,6 +8,7 @@
 
 _ARE_MODELS_REGISTERED = False
 
+
 def register_models():
     global _ARE_MODELS_REGISTERED
 
@@ -22,7 +23,15 @@ def register_models():
 
     _ARE_MODELS_REGISTERED = True
 
-def search_models(org: str = None, base_name: str = None, version: str = None, size: str = None, quant_types: list[QuantType] = None, search_pattern: str = None) -> list[ModelInfo]:
+
+def search_models(
+    org: str = None,
+    base_name: str = None,
+    version: str = None,
+    size: str = None,
+    quant_types: list[QuantType] = None,
+    search_pattern: str = None,
+) -> list[ModelInfo]:
     """
     Get model info from the registry.
 
@@ -33,19 +42,37 @@ def search_models(org: str = None, base_name: str = None, version: str = None, s
     """
     if not _ARE_MODELS_REGISTERED:
         register_models()
-    
+
     model_infos = MODEL_REGISTRY.values()
     if org:
-        model_infos = [model_info for model_info in model_infos if model_info.org == org]
+        model_infos = [
+            model_info for model_info in model_infos if model_info.org == org
+        ]
     if base_name:
-        model_infos = [model_info for model_info in model_infos if model_info.base_name == base_name]
+        model_infos = [
+            model_info
+            for model_info in model_infos
+            if model_info.base_name == base_name
+        ]
     if version:
-        model_infos = [model_info for model_info in model_infos if model_info.version == version]
+        model_infos = [
+            model_info for model_info in model_infos if model_info.version == version
+        ]
     if size:
-        model_infos = [model_info for model_info in model_infos if model_info.size == size]
+        model_infos = [
+            model_info for model_info in model_infos if model_info.size == size
+        ]
     if quant_types:
-        model_infos = [model_info for model_info in model_infos if any(model_info.quant_type == quant_type for quant_type in quant_types)]
+        model_infos = [
+            model_info
+            for model_info in model_infos
+            if any(model_info.quant_type == quant_type for quant_type in quant_types)
+        ]
     if search_pattern:
-        model_infos = [model_info for model_info in model_infos if search_pattern in model_info.model_path]
-    
-    return model_infos
\ No newline at end of file
+        model_infos = [
+            model_info
+            for model_info in model_infos
+            if search_pattern in model_info.model_path
+        ]
+
+    return model_infos
diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py
index 153a0e508..4bbc852cd 100644
--- a/unsloth/registry/_deepseek.py
+++ b/unsloth/registry/_deepseek.py
@@ -7,11 +7,15 @@
 _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = False
 _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = False
 
+
 class DeepseekV3ModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}-V{version}"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
+
 
 class DeepseekR1ModelInfo(ModelInfo):
     @classmethod
@@ -19,135 +23,157 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag
         key = f"{base_name}-{version}" if version else base_name
         if size:
             key = f"{key}-{size}B"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
-    
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
+
+
 # Deepseek V3 Model Meta
 DeepseekV3Meta = ModelMeta(
-    org="deepseek-ai",
-    base_name="DeepSeek",
-    instruct_tags=[None],
-    model_version="3",
-    model_sizes=[""],
-    model_info_cls=DeepseekV3ModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BF16],
+    org = "deepseek-ai",
+    base_name = "DeepSeek",
+    instruct_tags = [None],
+    model_version = "3",
+    model_sizes = [""],
+    model_info_cls = DeepseekV3ModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BF16],
 )
 
 DeepseekV3_0324Meta = ModelMeta(
-    org="deepseek-ai",
-    base_name="DeepSeek",
-    instruct_tags=[None],
-    model_version="3-0324",
-    model_sizes=[""],
-    model_info_cls=DeepseekV3ModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.GGUF],
+    org = "deepseek-ai",
+    base_name = "DeepSeek",
+    instruct_tags = [None],
+    model_version = "3-0324",
+    model_sizes = [""],
+    model_info_cls = DeepseekV3ModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.GGUF],
 )
 
 DeepseekR1Meta = ModelMeta(
-    org="deepseek-ai",
-    base_name="DeepSeek-R1",
-    instruct_tags=[None],
-    model_version="",
-    model_sizes=[""],
-    model_info_cls=DeepseekR1ModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF],
+    org = "deepseek-ai",
+    base_name = "DeepSeek-R1",
+    instruct_tags = [None],
+    model_version = "",
+    model_sizes = [""],
+    model_info_cls = DeepseekR1ModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BF16, QuantType.GGUF],
 )
 
 DeepseekR1ZeroMeta = ModelMeta(
-    org="deepseek-ai",
-    base_name="DeepSeek-R1",
-    instruct_tags=[None],
-    model_version="Zero",
-    model_sizes=[""],
-    model_info_cls=DeepseekR1ModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.GGUF],
+    org = "deepseek-ai",
+    base_name = "DeepSeek-R1",
+    instruct_tags = [None],
+    model_version = "Zero",
+    model_sizes = [""],
+    model_info_cls = DeepseekR1ModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.GGUF],
 )
 
 DeepseekR1DistillLlamaMeta = ModelMeta(
-    org="deepseek-ai",
-    base_name="DeepSeek-R1-Distill",
-    instruct_tags=[None],
-    model_version="Llama",
-    model_sizes=["8", "70"],
-    model_info_cls=DeepseekR1ModelInfo,
-    is_multimodal=False,
-    quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
+    org = "deepseek-ai",
+    base_name = "DeepSeek-R1-Distill",
+    instruct_tags = [None],
+    model_version = "Llama",
+    model_sizes = ["8", "70"],
+    model_info_cls = DeepseekR1ModelInfo,
+    is_multimodal = False,
+    quant_types = {"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
 )
 
 # Deepseek R1 Distill Qwen Model Meta
 DeepseekR1DistillQwenMeta = ModelMeta(
-    org="deepseek-ai",
-    base_name="DeepSeek-R1-Distill",
-    instruct_tags=[None],
-    model_version="Qwen",
-    model_sizes=["1.5", "7", "14", "32"],
-    model_info_cls=DeepseekR1ModelInfo,
-    is_multimodal=False,
-    quant_types={
+    org = "deepseek-ai",
+    base_name = "DeepSeek-R1-Distill",
+    instruct_tags = [None],
+    model_version = "Qwen",
+    model_sizes = ["1.5", "7", "14", "32"],
+    model_info_cls = DeepseekR1ModelInfo,
+    is_multimodal = False,
+    quant_types = {
         "1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
         "7": [QuantType.UNSLOTH, QuantType.BNB],
         "14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
         "32": [QuantType.GGUF, QuantType.BNB],
     },
 )
-        
+
+
 def register_deepseek_v3_models(include_original_model: bool = False):
     global _IS_DEEPSEEK_V3_REGISTERED
     if _IS_DEEPSEEK_V3_REGISTERED:
         return
-    _register_models(DeepseekV3Meta, include_original_model=include_original_model)
+    _register_models(DeepseekV3Meta, include_original_model = include_original_model)
     _IS_DEEPSEEK_V3_REGISTERED = True
 
+
 def register_deepseek_v3_0324_models(include_original_model: bool = False):
     global _IS_DEEPSEEK_V3_0324_REGISTERED
     if _IS_DEEPSEEK_V3_0324_REGISTERED:
         return
-    _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model)
+    _register_models(DeepseekV3_0324Meta, include_original_model = include_original_model)
     _IS_DEEPSEEK_V3_0324_REGISTERED = True
 
+
 def register_deepseek_r1_models(include_original_model: bool = False):
     global _IS_DEEPSEEK_R1_REGISTERED
     if _IS_DEEPSEEK_R1_REGISTERED:
         return
-    _register_models(DeepseekR1Meta, include_original_model=include_original_model)
+    _register_models(DeepseekR1Meta, include_original_model = include_original_model)
     _IS_DEEPSEEK_R1_REGISTERED = True
 
+
 def register_deepseek_r1_zero_models(include_original_model: bool = False):
     global _IS_DEEPSEEK_R1_ZERO_REGISTERED
     if _IS_DEEPSEEK_R1_ZERO_REGISTERED:
         return
-    _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model)
+    _register_models(DeepseekR1ZeroMeta, include_original_model = include_original_model)
     _IS_DEEPSEEK_R1_ZERO_REGISTERED = True
 
+
 def register_deepseek_r1_distill_llama_models(include_original_model: bool = False):
     global _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED
     if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED:
         return
-    _register_models(DeepseekR1DistillLlamaMeta, include_original_model=include_original_model)
+    _register_models(
+        DeepseekR1DistillLlamaMeta, include_original_model = include_original_model
+    )
     _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True
 
+
 def register_deepseek_r1_distill_qwen_models(include_original_model: bool = False):
     global _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED
     if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED:
         return
-    _register_models(DeepseekR1DistillQwenMeta, include_original_model=include_original_model)
+    _register_models(
+        DeepseekR1DistillQwenMeta, include_original_model = include_original_model
+    )
     _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True
 
+
 def register_deepseek_models(include_original_model: bool = False):
-    register_deepseek_v3_models(include_original_model=include_original_model)
-    register_deepseek_v3_0324_models(include_original_model=include_original_model)
-    register_deepseek_r1_models(include_original_model=include_original_model)
-    register_deepseek_r1_zero_models(include_original_model=include_original_model)
-    register_deepseek_r1_distill_llama_models(include_original_model=include_original_model)
-    register_deepseek_r1_distill_qwen_models(include_original_model=include_original_model)
+    register_deepseek_v3_models(include_original_model = include_original_model)
+    register_deepseek_v3_0324_models(include_original_model = include_original_model)
+    register_deepseek_r1_models(include_original_model = include_original_model)
+    register_deepseek_r1_zero_models(include_original_model = include_original_model)
+    register_deepseek_r1_distill_llama_models(
+        include_original_model = include_original_model
+    )
+    register_deepseek_r1_distill_qwen_models(
+        include_original_model = include_original_model
+    )
+
 
 def _list_deepseek_r1_distill_models():
     from unsloth.utils.hf_hub import ModelInfo as HfModelInfo
     from unsloth.utils.hf_hub import list_models
-    models: list[HfModelInfo] = list_models(author="unsloth", search="Distill", limit=1000)
+
+    models: list[HfModelInfo] = list_models(
+        author = "unsloth", search = "Distill", limit = 1000
+    )
     distill_models = []
     for model in models:
         model_id = model.id
@@ -159,14 +185,15 @@ def _list_deepseek_r1_distill_models():
     return distill_models
 
 
-register_deepseek_models(include_original_model=True)
+register_deepseek_models(include_original_model = True)
 
 if __name__ == "__main__":
     from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
+
     MODEL_REGISTRY.clear()
-    
-    register_deepseek_models(include_original_model=True)
-    
+
+    register_deepseek_models(include_original_model = True)
+
     for model_id, model_info in MODEL_REGISTRY.items():
         model_info = _check_model_info(model_id)
         if model_info is None:
@@ -176,4 +203,4 @@ def _list_deepseek_r1_distill_models():
     # distill_models = _list_deepseek_r1_distill_models()
     # for model in sorted(distill_models):
     #     if "qwen" in model.lower():
-    #         print(model)
\ No newline at end of file
+    #         print(model)
diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py
index 9490c84f2..c338128bc 100644
--- a/unsloth/registry/_gemma.py
+++ b/unsloth/registry/_gemma.py
@@ -3,61 +3,69 @@
 _IS_GEMMA_3_BASE_REGISTERED = False
 _IS_GEMMA_3_INSTRUCT_REGISTERED = False
 
+
 class GemmaModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}-{version}-{size}B"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
+
 
 # Gemma3 Base Model Meta
 GemmaMeta3Base = ModelMeta(
-    org="google",
-    base_name="gemma",
-    instruct_tags=["pt"],  # pt = base
-    model_version="3",
-    model_sizes=["1", "4", "12", "27"],
-    model_info_cls=GemmaModelInfo,
-    is_multimodal=True,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+    org = "google",
+    base_name = "gemma",
+    instruct_tags = ["pt"],  # pt = base
+    model_version = "3",
+    model_sizes = ["1", "4", "12", "27"],
+    model_info_cls = GemmaModelInfo,
+    is_multimodal = True,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
 )
 
 # Gemma3 Instruct Model Meta
 GemmaMeta3Instruct = ModelMeta(
-    org="google",
-    base_name="gemma",
-    instruct_tags=["it"],  # it = instruction tuned
-    model_version="3",
-    model_sizes=["1", "4", "12", "27"],
-    model_info_cls=GemmaModelInfo,
-    is_multimodal=True,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
+    org = "google",
+    base_name = "gemma",
+    instruct_tags = ["it"],  # it = instruction tuned
+    model_version = "3",
+    model_sizes = ["1", "4", "12", "27"],
+    model_info_cls = GemmaModelInfo,
+    is_multimodal = True,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
 )
 
+
 def register_gemma_3_base_models(include_original_model: bool = False):
     global _IS_GEMMA_3_BASE_REGISTERED
     if _IS_GEMMA_3_BASE_REGISTERED:
         return
-    _register_models(GemmaMeta3Base, include_original_model=include_original_model)
+    _register_models(GemmaMeta3Base, include_original_model = include_original_model)
     _IS_GEMMA_3_BASE_REGISTERED = True
 
+
 def register_gemma_3_instruct_models(include_original_model: bool = False):
     global _IS_GEMMA_3_INSTRUCT_REGISTERED
     if _IS_GEMMA_3_INSTRUCT_REGISTERED:
         return
-    _register_models(GemmaMeta3Instruct, include_original_model=include_original_model)
+    _register_models(GemmaMeta3Instruct, include_original_model = include_original_model)
     _IS_GEMMA_3_INSTRUCT_REGISTERED = True
 
+
 def register_gemma_models(include_original_model: bool = False):
-    register_gemma_3_base_models(include_original_model=include_original_model)
-    register_gemma_3_instruct_models(include_original_model=include_original_model)
+    register_gemma_3_base_models(include_original_model = include_original_model)
+    register_gemma_3_instruct_models(include_original_model = include_original_model)
 
 
 if __name__ == "__main__":
     from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
+
     MODEL_REGISTRY.clear()
-    
-    register_gemma_models(include_original_model=True)
-    
+
+    register_gemma_models(include_original_model = True)
+
     for model_id, model_info in MODEL_REGISTRY.items():
         model_info = _check_model_info(model_id)
         if model_info is None:
diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py
index f1b9dbdd3..f5f82372b 100644
--- a/unsloth/registry/_llama.py
+++ b/unsloth/registry/_llama.py
@@ -9,62 +9,66 @@ class LlamaModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}-{version}-{size}B"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
 
 
 class LlamaVisionModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}-{version}-{size}B-Vision"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
 
 
 # Llama 3.1
 LlamaMeta_3_1 = ModelMeta(
-    org="meta-llama",
-    base_name="Llama",
-    instruct_tags=[None, "Instruct"],
-    model_version="3.1",
-    model_sizes=["8"],
-    model_info_cls=LlamaModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+    org = "meta-llama",
+    base_name = "Llama",
+    instruct_tags = [None, "Instruct"],
+    model_version = "3.1",
+    model_sizes = ["8"],
+    model_info_cls = LlamaModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
 )
 
 # Llama 3.2 Base Models
 LlamaMeta_3_2_Base = ModelMeta(
-    org="meta-llama",
-    base_name="Llama",
-    instruct_tags=[None],
-    model_version="3.2",
-    model_sizes=["1", "3"],
-    model_info_cls=LlamaModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+    org = "meta-llama",
+    base_name = "Llama",
+    instruct_tags = [None],
+    model_version = "3.2",
+    model_sizes = ["1", "3"],
+    model_info_cls = LlamaModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
 )
 
 # Llama 3.2 Instruction Tuned Models
 LlamaMeta_3_2_Instruct = ModelMeta(
-    org="meta-llama",
-    base_name="Llama",
-    instruct_tags=["Instruct"],
-    model_version="3.2",
-    model_sizes=["1", "3"],
-    model_info_cls=LlamaModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
+    org = "meta-llama",
+    base_name = "Llama",
+    instruct_tags = ["Instruct"],
+    model_version = "3.2",
+    model_sizes = ["1", "3"],
+    model_info_cls = LlamaModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
 )
 
 # Llama 3.2 Vision
 LlamaMeta_3_2_Vision = ModelMeta(
-    org="meta-llama",
-    base_name="Llama",
-    instruct_tags=[None, "Instruct"],
-    model_version="3.2",
-    model_sizes=["11", "90"],
-    model_info_cls=LlamaVisionModelInfo,
-    is_multimodal=True,
-    quant_types={
+    org = "meta-llama",
+    base_name = "Llama",
+    instruct_tags = [None, "Instruct"],
+    model_version = "3.2",
+    model_sizes = ["11", "90"],
+    model_info_cls = LlamaVisionModelInfo,
+    is_multimodal = True,
+    quant_types = {
         "11": [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
         "90": [QuantType.NONE],
     },
@@ -75,35 +79,43 @@ def register_llama_3_1_models(include_original_model: bool = False):
     global _IS_LLAMA_3_1_REGISTERED
     if _IS_LLAMA_3_1_REGISTERED:
         return
-    _register_models(LlamaMeta_3_1, include_original_model=include_original_model)
+    _register_models(LlamaMeta_3_1, include_original_model = include_original_model)
     _IS_LLAMA_3_1_REGISTERED = True
 
+
 def register_llama_3_2_models(include_original_model: bool = False):
     global _IS_LLAMA_3_2_REGISTERED
     if _IS_LLAMA_3_2_REGISTERED:
         return
-    _register_models(LlamaMeta_3_2_Base, include_original_model=include_original_model)
-    _register_models(LlamaMeta_3_2_Instruct, include_original_model=include_original_model)
+    _register_models(LlamaMeta_3_2_Base, include_original_model = include_original_model)
+    _register_models(
+        LlamaMeta_3_2_Instruct, include_original_model = include_original_model
+    )
     _IS_LLAMA_3_2_REGISTERED = True
 
+
 def register_llama_3_2_vision_models(include_original_model: bool = False):
     global _IS_LLAMA_3_2_VISION_REGISTERED
     if _IS_LLAMA_3_2_VISION_REGISTERED:
         return
-    _register_models(LlamaMeta_3_2_Vision, include_original_model=include_original_model)
+    _register_models(
+        LlamaMeta_3_2_Vision, include_original_model = include_original_model
+    )
     _IS_LLAMA_3_2_VISION_REGISTERED = True
 
 
 def register_llama_models(include_original_model: bool = False):
-    register_llama_3_1_models(include_original_model=include_original_model)
-    register_llama_3_2_models(include_original_model=include_original_model)
-    register_llama_3_2_vision_models(include_original_model=include_original_model)
+    register_llama_3_1_models(include_original_model = include_original_model)
+    register_llama_3_2_models(include_original_model = include_original_model)
+    register_llama_3_2_vision_models(include_original_model = include_original_model)
+
 
 if __name__ == "__main__":
     from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
+
     MODEL_REGISTRY.clear()
 
-    register_llama_models(include_original_model=True)
+    register_llama_models(include_original_model = True)
 
     for model_id, model_info in MODEL_REGISTRY.items():
         model_info = _check_model_info(model_id)
diff --git a/unsloth/registry/_mistral.py b/unsloth/registry/_mistral.py
index 44cd1e764..173d6cfde 100644
--- a/unsloth/registry/_mistral.py
+++ b/unsloth/registry/_mistral.py
@@ -6,7 +6,8 @@
 
 _MISTRAL_SMALL_03_25_VERSION = "2503"
 _MISTRAL_SMALL_01_25_VERSION = "2501"
-_MISTRAL_SMALL_09_24_VERSION = "2409" # Not uploaded to unsloth
+_MISTRAL_SMALL_09_24_VERSION = "2409"  # Not uploaded to unsloth
+
 
 class MistralSmallModelInfo(ModelInfo):
     @classmethod
@@ -17,24 +18,29 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag
             key = f"{base_name}-{size}B-{instruct_tag}"
         key += f"-{version}"
         key = cls.append_quant_type(key, quant_type)
-        
+
         return key
 
 
 MistralSmall_2503_Base_Meta = ModelMeta(
-    org="mistralai",
-    base_name="Mistral-Small",
-    instruct_tags=["Base"],
-    model_version=_MISTRAL_SMALL_03_25_VERSION,
-    model_sizes=["24"],
-    model_info_cls=MistralSmallModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
+    org = "mistralai",
+    base_name = "Mistral-Small",
+    instruct_tags = ["Base"],
+    model_version = _MISTRAL_SMALL_03_25_VERSION,
+    model_sizes = ["24"],
+    model_info_cls = MistralSmallModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
 )
 
 MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)
 MistralSmall_2503_Instruct_Meta.instruct_tags = ["Instruct"]
-MistralSmall_2503_Instruct_Meta.quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF]
+MistralSmall_2503_Instruct_Meta.quant_types = [
+    QuantType.NONE,
+    QuantType.UNSLOTH,
+    QuantType.BNB,
+    QuantType.GGUF,
+]
 
 MistralSmall_2501_Base_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)
 MistralSmall_2501_Base_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION
@@ -42,29 +48,41 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag
 MistralSmall_2501_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Instruct_Meta)
 MistralSmall_2501_Instruct_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION
 
+
 def register_mistral_small_models(include_original_model: bool = False):
     global _IS_MISTRAL_SMALL_REGISTERED
     if _IS_MISTRAL_SMALL_REGISTERED:
         return
-    _register_models(MistralSmall_2503_Base_Meta, include_original_model=include_original_model)
-    _register_models(MistralSmall_2503_Instruct_Meta, include_original_model=include_original_model)
-    _register_models(MistralSmall_2501_Base_Meta, include_original_model=include_original_model)
-    _register_models(MistralSmall_2501_Instruct_Meta, include_original_model=include_original_model)
+    _register_models(
+        MistralSmall_2503_Base_Meta, include_original_model = include_original_model
+    )
+    _register_models(
+        MistralSmall_2503_Instruct_Meta, include_original_model = include_original_model
+    )
+    _register_models(
+        MistralSmall_2501_Base_Meta, include_original_model = include_original_model
+    )
+    _register_models(
+        MistralSmall_2501_Instruct_Meta, include_original_model = include_original_model
+    )
 
     _IS_MISTRAL_SMALL_REGISTERED = True
 
+
 def register_mistral_models(include_original_model: bool = False):
-    register_mistral_small_models(include_original_model=include_original_model)
+    register_mistral_small_models(include_original_model = include_original_model)
+
 
 if __name__ == "__main__":
     from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
+
     MODEL_REGISTRY.clear()
-    
-    register_mistral_models(include_original_model=True)
-    
+
+    register_mistral_models(include_original_model = True)
+
     for model_id, model_info in MODEL_REGISTRY.items():
         model_info = _check_model_info(model_id)
         if model_info is None:
             print(f"\u2718 {model_id}")
         else:
-            print(f"\u2713 {model_id}")    
\ No newline at end of file
+            print(f"\u2713 {model_id}")
diff --git a/unsloth/registry/_phi.py b/unsloth/registry/_phi.py
index d06ec8d37..a6f773c48 100644
--- a/unsloth/registry/_phi.py
+++ b/unsloth/registry/_phi.py
@@ -3,63 +3,72 @@
 _IS_PHI_4_REGISTERED = False
 _IS_PHI_4_INSTRUCT_REGISTERED = False
 
+
 class PhiModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}-{version}"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
+
 
 # Phi Model Meta
 PhiMeta4 = ModelMeta(
-    org="microsoft",
-    base_name="phi",
-    instruct_tags=[None],
-    model_version="4",
-    model_sizes=["1"],  # Assuming only one size
-    model_info_cls=PhiModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+    org = "microsoft",
+    base_name = "phi",
+    instruct_tags = [None],
+    model_version = "4",
+    model_sizes = ["1"],  # Assuming only one size
+    model_info_cls = PhiModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
 )
 
 # Phi Instruct Model Meta
 PhiInstructMeta4 = ModelMeta(
-    org="microsoft",
-    base_name="phi",
-    instruct_tags=["mini-instruct"],
-    model_version="4",
-    model_sizes=["1"],  # Assuming only one size
-    model_info_cls=PhiModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
+    org = "microsoft",
+    base_name = "phi",
+    instruct_tags = ["mini-instruct"],
+    model_version = "4",
+    model_sizes = ["1"],  # Assuming only one size
+    model_info_cls = PhiModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
 )
 
+
 def register_phi_4_models(include_original_model: bool = False):
     global _IS_PHI_4_REGISTERED
     if _IS_PHI_4_REGISTERED:
         return
-    _register_models(PhiMeta4, include_original_model=include_original_model)
+    _register_models(PhiMeta4, include_original_model = include_original_model)
     _IS_PHI_4_REGISTERED = True
 
+
 def register_phi_4_instruct_models(include_original_model: bool = False):
     global _IS_PHI_4_INSTRUCT_REGISTERED
     if _IS_PHI_4_INSTRUCT_REGISTERED:
         return
-    _register_models(PhiInstructMeta4, include_original_model=include_original_model)
+    _register_models(PhiInstructMeta4, include_original_model = include_original_model)
     _IS_PHI_4_INSTRUCT_REGISTERED = True
 
+
 def register_phi_models(include_original_model: bool = False):
-    register_phi_4_models(include_original_model=include_original_model)
-    register_phi_4_instruct_models(include_original_model=include_original_model)
+    register_phi_4_models(include_original_model = include_original_model)
+    register_phi_4_instruct_models(include_original_model = include_original_model)
+
 
 if __name__ == "__main__":
     from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
+
     MODEL_REGISTRY.clear()
-    
-    register_phi_models(include_original_model=True)
-    
+
+    register_phi_models(include_original_model = True)
+
     for model_id, model_info in MODEL_REGISTRY.items():
         model_info = _check_model_info(model_id)
         if model_info is None:
             print(f"\u2718 {model_id}")
         else:
-            print(f"\u2713 {model_id}") 
\ No newline at end of file
+            print(f"\u2713 {model_id}")
diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py
index 4417515a7..f852cb876 100644
--- a/unsloth/registry/_qwen.py
+++ b/unsloth/registry/_qwen.py
@@ -3,112 +3,131 @@
 _IS_QWEN_2_5_REGISTERED = False
 _IS_QWEN_2_5_VL_REGISTERED = False
 _IS_QWEN_QWQ_REGISTERED = False
+
+
 class QwenModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}{version}-{size}B"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
 
 
 class QwenVLModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}{version}-VL-{size}B"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
+
 
 class QwenQwQModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}-{size}B"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
-    
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
+
+
 class QwenQVQPreviewModelInfo(ModelInfo):
     @classmethod
     def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
         key = f"{base_name}-{size}B-Preview"
-        return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
-    
+        return super().construct_model_name(
+            base_name, version, size, quant_type, instruct_tag, key
+        )
+
+
 # Qwen2.5 Model Meta
 Qwen_2_5_Meta = ModelMeta(
-    org="Qwen",
-    base_name="Qwen",
-    instruct_tags=[None, "Instruct"],
-    model_version="2.5",
-    model_sizes=["3", "7"],
-    model_info_cls=QwenModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+    org = "Qwen",
+    base_name = "Qwen",
+    instruct_tags = [None, "Instruct"],
+    model_version = "2.5",
+    model_sizes = ["3", "7"],
+    model_info_cls = QwenModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
 )
 
 # Qwen2.5 VL Model Meta
 Qwen_2_5_VLMeta = ModelMeta(
-    org="Qwen",
-    base_name="Qwen",
-    instruct_tags=["Instruct"],  # No base, only instruction tuned
-    model_version="2.5",
-    model_sizes=["3", "7", "32", "72"],
-    model_info_cls=QwenVLModelInfo,
-    is_multimodal=True,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
+    org = "Qwen",
+    base_name = "Qwen",
+    instruct_tags = ["Instruct"],  # No base, only instruction tuned
+    model_version = "2.5",
+    model_sizes = ["3", "7", "32", "72"],
+    model_info_cls = QwenVLModelInfo,
+    is_multimodal = True,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
 )
 
 # Qwen QwQ Model Meta
 QwenQwQMeta = ModelMeta(
-    org="Qwen",
-    base_name="QwQ",
-    instruct_tags=[None],
-    model_version="",
-    model_sizes=["32"],
-    model_info_cls=QwenQwQModelInfo,
-    is_multimodal=False,
-    quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
+    org = "Qwen",
+    base_name = "QwQ",
+    instruct_tags = [None],
+    model_version = "",
+    model_sizes = ["32"],
+    model_info_cls = QwenQwQModelInfo,
+    is_multimodal = False,
+    quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
 )
 
 # Qwen QVQ Preview Model Meta
 QwenQVQPreviewMeta = ModelMeta(
-    org="Qwen",
-    base_name="QVQ",
-    instruct_tags=[None],
-    model_version="",
-    model_sizes=["72"],
-    model_info_cls=QwenQVQPreviewModelInfo,
-    is_multimodal=True,
-    quant_types=[QuantType.NONE, QuantType.BNB],
+    org = "Qwen",
+    base_name = "QVQ",
+    instruct_tags = [None],
+    model_version = "",
+    model_sizes = ["72"],
+    model_info_cls = QwenQVQPreviewModelInfo,
+    is_multimodal = True,
+    quant_types = [QuantType.NONE, QuantType.BNB],
 )
 
+
 def register_qwen_2_5_models(include_original_model: bool = False):
     global _IS_QWEN_2_5_REGISTERED
     if _IS_QWEN_2_5_REGISTERED:
         return
-    _register_models(Qwen_2_5_Meta, include_original_model=include_original_model)
+    _register_models(Qwen_2_5_Meta, include_original_model = include_original_model)
     _IS_QWEN_2_5_REGISTERED = True
 
+
 def register_qwen_2_5_vl_models(include_original_model: bool = False):
     global _IS_QWEN_2_5_VL_REGISTERED
     if _IS_QWEN_2_5_VL_REGISTERED:
         return
-    _register_models(Qwen_2_5_VLMeta, include_original_model=include_original_model)
+    _register_models(Qwen_2_5_VLMeta, include_original_model = include_original_model)
     _IS_QWEN_2_5_VL_REGISTERED = True
 
+
 def register_qwen_qwq_models(include_original_model: bool = False):
     global _IS_QWEN_QWQ_REGISTERED
     if _IS_QWEN_QWQ_REGISTERED:
         return
-    _register_models(QwenQwQMeta, include_original_model=include_original_model)
-    _register_models(QwenQVQPreviewMeta, include_original_model=include_original_model)
+    _register_models(QwenQwQMeta, include_original_model = include_original_model)
+    _register_models(QwenQVQPreviewMeta, include_original_model = include_original_model)
     _IS_QWEN_QWQ_REGISTERED = True
 
+
 def register_qwen_models(include_original_model: bool = False):
-    register_qwen_2_5_models(include_original_model=include_original_model)
-    register_qwen_2_5_vl_models(include_original_model=include_original_model)
-    register_qwen_qwq_models(include_original_model=include_original_model)
+    register_qwen_2_5_models(include_original_model = include_original_model)
+    register_qwen_2_5_vl_models(include_original_model = include_original_model)
+    register_qwen_qwq_models(include_original_model = include_original_model)
+
 
 if __name__ == "__main__":
     from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
+
     MODEL_REGISTRY.clear()
-    
-    register_qwen_models(include_original_model=True)
-    
+
+    register_qwen_models(include_original_model = True)
+
     for model_id, model_info in MODEL_REGISTRY.items():
         model_info = _check_model_info(model_id)
         if model_info is None:
diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py
index 590beebee..945301420 100644
--- a/unsloth/registry/registry.py
+++ b/unsloth/registry/registry.py
@@ -5,10 +5,11 @@
 
 class QuantType(Enum):
     BNB = "bnb"
-    UNSLOTH = "unsloth" # dynamic 4-bit quantization
+    UNSLOTH = "unsloth"  # dynamic 4-bit quantization
     GGUF = "GGUF"
     NONE = "none"
-    BF16 = "bf16" # only for Deepseek V3
+    BF16 = "bf16"  # only for Deepseek V3
+
 
 # Tags for Hugging Face model paths
 BNB_QUANTIZED_TAG = "bnb-4bit"
@@ -22,7 +23,8 @@ class QuantType(Enum):
     QuantType.GGUF: GGUF_TAG,
     QuantType.NONE: None,
     QuantType.BF16: BF16_TAG,
-} 
+}
+
 
 # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH
 @dataclass
@@ -53,15 +55,15 @@ def append_instruct_tag(key: str, instruct_tag: str = None):
         return key
 
     @staticmethod
-    def append_quant_type(
-        key: str, quant_type: QuantType = None
-    ):
+    def append_quant_type(key: str, quant_type: QuantType = None):
         if quant_type != QuantType.NONE:
             key = "-".join([key, QUANT_TAG_MAP[quant_type]])
         return key
 
     @classmethod
-    def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag, key=""):
+    def construct_model_name(
+        cls, base_name, version, size, quant_type, instruct_tag, key = ""
+    ):
         key = cls.append_instruct_tag(key, instruct_tag)
         key = cls.append_quant_type(key, quant_type)
         return key
@@ -79,9 +81,11 @@ class ModelMeta:
     base_name: str
     model_version: str
     model_info_cls: type[ModelInfo]
-    model_sizes: list[str] = field(default_factory=list)
-    instruct_tags: list[str] = field(default_factory=list)
-    quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list)
+    model_sizes: list[str] = field(default_factory = list)
+    instruct_tags: list[str] = field(default_factory = list)
+    quant_types: list[QuantType] | dict[str, list[QuantType]] = field(
+        default_factory = list
+    )
     is_multimodal: bool = False
 
 
@@ -100,26 +104,28 @@ def register_model(
     name: str = None,
 ):
     name = name or model_info_cls.construct_model_name(
-        base_name=base_name,
-        version=version,
-        size=size,
-        quant_type=quant_type,
-        instruct_tag=instruct_tag,
+        base_name = base_name,
+        version = version,
+        size = size,
+        quant_type = quant_type,
+        instruct_tag = instruct_tag,
     )
     key = f"{org}/{name}"
 
     if key in MODEL_REGISTRY:
-        raise ValueError(f"Model {key} already registered, current keys: {MODEL_REGISTRY.keys()}")
+        raise ValueError(
+            f"Model {key} already registered, current keys: {MODEL_REGISTRY.keys()}"
+        )
 
     MODEL_REGISTRY[key] = model_info_cls(
-        org=org,
-        base_name=base_name,
-        version=version,
-        size=size,
-        is_multimodal=is_multimodal,
-        instruct_tag=instruct_tag,
-        quant_type=quant_type,
-        name=name,
+        org = org,
+        base_name = base_name,
+        version = version,
+        size = size,
+        is_multimodal = is_multimodal,
+        instruct_tag = instruct_tag,
+        quant_type = quant_type,
+        name = name,
     )
 
 
@@ -131,7 +137,7 @@ def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]):
     api = HfApi()
 
     try:
-        model_info: HfModelInfo = api.model_info(model_id, expand=properties)
+        model_info: HfModelInfo = api.model_info(model_id, expand = properties)
     except Exception as e:
         if isinstance(e, RepositoryNotFoundError):
             warnings.warn(f"{model_id} not found on Hugging Face")
@@ -160,26 +166,26 @@ def _register_models(model_meta: ModelMeta, include_original_model: bool = False
                 _quant_types = quant_types
             for quant_type in _quant_types:
                 # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH
-                _org = "unsloth" # unsloth models -- these are all quantized versions of the original model
+                _org = "unsloth"  # unsloth models -- these are all quantized versions of the original model
                 register_model(
-                    model_info_cls=model_info_cls,
-                    org=_org,
-                    base_name=base_name,
-                    version=model_version,
-                    size=size,
-                    instruct_tag=instruct_tag,
-                    quant_type=quant_type,
-                    is_multimodal=is_multimodal,
+                    model_info_cls = model_info_cls,
+                    org = _org,
+                    base_name = base_name,
+                    version = model_version,
+                    size = size,
+                    instruct_tag = instruct_tag,
+                    quant_type = quant_type,
+                    is_multimodal = is_multimodal,
                 )
             # include original model from releasing organization
             if include_original_model:
                 register_model(
-                    model_info_cls=model_info_cls,
-                    org=org,
-                    base_name=base_name,
-                    version=model_version,
-                    size=size,
-                    instruct_tag=instruct_tag,
-                    quant_type=QuantType.NONE,
-                    is_multimodal=is_multimodal,
+                    model_info_cls = model_info_cls,
+                    org = org,
+                    base_name = base_name,
+                    version = model_version,
+                    size = size,
+                    instruct_tag = instruct_tag,
+                    quant_type = QuantType.NONE,
+                    is_multimodal = is_multimodal,
                 )
diff --git a/unsloth/save.py b/unsloth/save.py
index 4ca59bbda..ba1724b2d 100644
--- a/unsloth/save.py
+++ b/unsloth/save.py
@@ -15,7 +15,14 @@
 from unsloth_zoo.utils import Version
 from importlib.metadata import version as importlib_version
 from unsloth_zoo.hf_utils import dtype_from_config, HAS_TORCH_DTYPE
-from unsloth_zoo.llama_cpp import convert_to_gguf, quantize_gguf, use_local_gguf, install_llama_cpp, check_llama_cpp, _download_convert_hf_to_gguf
+from unsloth_zoo.llama_cpp import (
+    convert_to_gguf,
+    quantize_gguf,
+    use_local_gguf,
+    install_llama_cpp,
+    check_llama_cpp,
+    _download_convert_hf_to_gguf,
+)
 from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
 from peft.tuners.lora import Linear4bit as Peft_Linear4bit
 from peft.tuners.lora import Linear as Peft_Linear
@@ -38,6 +45,7 @@
 from .ollama_template_mappers import OLLAMA_TEMPLATES, MODEL_TO_OLLAMA_TEMPLATE_MAPPER
 from transformers import ProcessorMixin
 from huggingface_hub import HfApi
+
 try:
     from huggingface_hub import get_token
 except:
@@ -46,8 +54,6 @@
     except:
         # For older versions of huggingface_hub
         from huggingface_hub.utils._token import get_token
-    pass
-pass
 from pathlib import Path
 from peft import PeftModelForCausalLM, PeftModel
 
@@ -60,67 +66,80 @@
 ]
 
 # llama.cpp specific targets - all takes 90s. Below takes 60s
-LLAMA_CPP_TARGETS = ["llama-quantize", "llama-export-lora", "llama-cli",]
+LLAMA_CPP_TARGETS = [
+    "llama-quantize",
+    "llama-export-lora",
+    "llama-cli",
+]
 
 # Check environments
 keynames = "\n" + "\n".join(os.environ.keys())
-IS_COLAB_ENVIRONMENT  = "\nCOLAB_"  in keynames
+IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
 IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames
 KAGGLE_TMP = "/tmp"
 del keynames
 
 # Weights
 LLAMA_WEIGHTS = (
-    "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
-    "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
+    "self_attn.q_proj",
+    "self_attn.k_proj",
+    "self_attn.v_proj",
+    "self_attn.o_proj",
+    "mlp.gate_proj",
+    "mlp.up_proj",
+    "mlp.down_proj",
 )
 LLAMA_LAYERNORMS = (
-    "input_layernorm", "post_attention_layernorm",
-    "pre_feedforward_layernorm", "post_feedforward_layernorm",
-    "self_attn.q_norm", "self_attn.k_norm",
+    "input_layernorm",
+    "post_attention_layernorm",
+    "pre_feedforward_layernorm",
+    "post_feedforward_layernorm",
+    "self_attn.q_norm",
+    "self_attn.k_norm",
 )
 
 # https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19
 # From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html
-ALLOWED_QUANTS = \
-{
-    "not_quantized"  : "Recommended. Fast conversion. Slow inference, big files.",
-    "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
-    "quantized"      : "Recommended. Slow conversion. Fast inference, small files.",
-    "f32"     : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
-    "bf16"    : "Bfloat16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
-    "f16"     : "Float16  - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
-    "q8_0"    : "Fast conversion. High resource use, but generally acceptable.",
-    "q4_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
-    "q5_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
-    "q2_k"    : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
-    "q3_k_l"  : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
-    "q3_k_m"  : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
-    "q3_k_s"  : "Uses Q3_K for all tensors",
-    "q4_0"    : "Original quant method, 4-bit.",
-    "q4_1"    : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
-    "q4_k_s"  : "Uses Q4_K for all tensors",
-    "q4_k"    : "alias for q4_k_m",
-    "q5_k"    : "alias for q5_k_m",
-    "q5_0"    : "Higher accuracy, higher resource usage and slower inference.",
-    "q5_1"    : "Even higher accuracy, resource usage and slower inference.",
-    "q5_k_s"  : "Uses Q5_K for all tensors",
-    "q6_k"    : "Uses Q8_K for all tensors",
+ALLOWED_QUANTS = {
+    "not_quantized": "Recommended. Fast conversion. Slow inference, big files.",
+    "fast_quantized": "Recommended. Fast conversion. OK inference, OK file size.",
+    "quantized": "Recommended. Slow conversion. Fast inference, small files.",
+    "f32": "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
+    "bf16": "Bfloat16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+    "f16": "Float16  - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+    "q8_0": "Fast conversion. High resource use, but generally acceptable.",
+    "q4_k_m": "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
+    "q5_k_m": "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
+    "q2_k": "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
+    "q3_k_l": "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+    "q3_k_m": "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+    "q3_k_s": "Uses Q3_K for all tensors",
+    "q4_0": "Original quant method, 4-bit.",
+    "q4_1": "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
+    "q4_k_s": "Uses Q4_K for all tensors",
+    "q4_k": "alias for q4_k_m",
+    "q5_k": "alias for q5_k_m",
+    "q5_0": "Higher accuracy, higher resource usage and slower inference.",
+    "q5_1": "Even higher accuracy, resource usage and slower inference.",
+    "q5_k_s": "Uses Q5_K for all tensors",
+    "q6_k": "Uses Q8_K for all tensors",
     # "iq2_xxs" : "2.06 bpw quantization", # Not supported sadly
     # "iq2_xs"  : "2.31 bpw quantization",
     # "iq3_xxs" : "3.06 bpw quantization",
-    "q3_k_xs" : "3-bit extra small quantization",
+    "q3_k_xs": "3-bit extra small quantization",
 }
 
+
 def print_quantization_methods():
     for key, value in ALLOWED_QUANTS.items():
         print(f'"{key}"  ==> {value}')
-    pass
-pass
 
 
-def check_if_sentencepiece_model(model, temporary_location = "_unsloth_sentencepiece_temp"):
-    if not hasattr(model, "_saved_temp_tokenizer"): return False
+def check_if_sentencepiece_model(
+    model, temporary_location = "_unsloth_sentencepiece_temp"
+):
+    if not hasattr(model, "_saved_temp_tokenizer"):
+        return False
 
     temp_tokenizer = model._saved_temp_tokenizer
     sentencepiece_model = False
@@ -129,19 +148,17 @@ def check_if_sentencepiece_model(model, temporary_location = "_unsloth_sentencep
     if not os.path.exists(file_location):
         created_folder = True
         os.makedirs(file_location)
-    pass
     temp_tokenizer.save_pretrained(file_location)
     if os.path.isfile(f"{file_location}/tokenizer.model"):
         sentencepiece_model = True
-    pass
     if created_folder:
         shutil.rmtree(file_location, ignore_errors = True)
     return sentencepiece_model
-pass
 
 
 def _free_cached_model(model):
     from huggingface_hub import scan_cache_dir
+
     cached_repos = list(scan_cache_dir().repos)
 
     # Go through every cached repo, and delete the one that matches the model we want to save.
@@ -149,27 +166,27 @@ def _free_cached_model(model):
     for cached_repo in cached_repos:
         if cached_repo.repo_id == model.config._name_or_path:
             remove_cache_commit = list(cached_repo.revisions)[0].commit_hash
-            delete_strategy = scan_cache_dir().delete_revisions(remove_cache_commit,)
+            delete_strategy = scan_cache_dir().delete_revisions(
+                remove_cache_commit,
+            )
 
             logger.warning_once(
-                "Unsloth: Will remove a cached repo with size " + \
-                delete_strategy.expected_freed_size_str,
+                "Unsloth: Will remove a cached repo with size "
+                + delete_strategy.expected_freed_size_str,
             )
 
             delete_strategy.execute()
-        pass
-    pass
-pass
 
 
 def _merge_lora(layer, name):
-
     bias = getattr(layer, "bias", None)
     if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
         # Is LoRA so we need to merge!
         W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer)
         if quant_state is not None:
-            dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
+            dtype = (
+                quant_state.dtype if type(quant_state) is not list else quant_state[2]
+            )
             W = fast_dequantize(W, quant_state)
         else:
             dtype = W.dtype
@@ -184,13 +201,13 @@ def _merge_lora(layer, name):
             # if not torch.isfinite(W).all():
             maximum_element = torch.max(W.min().abs(), W.max())
             if not torch.isfinite(maximum_element).item():
-                raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
-        pass
+                raise ValueError(
+                    f"Unsloth: Merge failed.\n{name} has some elements = infinity."
+                )
         W = W.t().to(dtype)
     else:
         W = layer.weight
     return W, bias
-pass
 
 
 def fast_save_pickle(shard, name):
@@ -204,41 +221,40 @@ def fast_save_pickle(shard, name):
         # pickle_protocol = pickle.HIGHEST_PROTOCOL,
     )
     return
-pass
 
 
 @torch.inference_mode
 def unsloth_save_model(
     model,
     tokenizer,
-    save_directory       : Union[str, os.PathLike],
-    save_method          : str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
-    push_to_hub          : bool = False,
-    token                : Optional[Union[str, bool]] = None,
-    is_main_process      : bool = True,
-    state_dict           : Optional[dict] = None,
-    save_function        : Callable = torch.save,
-    max_shard_size       : Union[int, str] = "5GB",
-    safe_serialization   : bool = True,
-    variant              : Optional[str] = None,
-    save_peft_format     : bool = True,
-
+    save_directory: Union[str, os.PathLike],
+    save_method: str = "lora",  # ["lora", "merged_16bit", "merged_4bit"]
+    push_to_hub: bool = False,
+    token: Optional[Union[str, bool]] = None,
+    is_main_process: bool = True,
+    state_dict: Optional[dict] = None,
+    save_function: Callable = torch.save,
+    max_shard_size: Union[int, str] = "5GB",
+    safe_serialization: bool = True,
+    variant: Optional[str] = None,
+    save_peft_format: bool = True,
     # Push to hub
-    use_temp_dir         : Optional[bool] = None,
-    commit_message       : Optional[str] = "Trained with Unsloth",
-    private              : Optional[bool] = None,
-    create_pr            : bool = False,
-    revision             : str = None,
-    commit_description   : str = "Upload model trained with Unsloth 2x faster",
-    tags                 : List[str] = None,
-
+    use_temp_dir: Optional[bool] = None,
+    commit_message: Optional[str] = "Trained with Unsloth",
+    private: Optional[bool] = None,
+    create_pr: bool = False,
+    revision: str = None,
+    commit_description: str = "Upload model trained with Unsloth 2x faster",
+    tags: List[str] = None,
     # Our functions
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.9,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.9,
 ):
-    if token is None: token = get_token()
+    if token is None:
+        token = get_token()
 
-    if commit_message is None: commit_message = ""
+    if commit_message is None:
+        commit_message = ""
     if "Unsloth" not in commit_message:
         commit_message += " (Trained with Unsloth)"
     commit_message = commit_message.lstrip()
@@ -247,185 +263,211 @@ def unsloth_save_model(
         commit_description = "Upload model trained with Unsloth 2x faster"
     elif "Unsloth 2x faster" not in commit_description:
         commit_description += " (Trained with Unsloth 2x faster)"
-    pass
 
     if save_method == "merged_4bit":
         raise RuntimeError(
-            "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"\
-            "to merge to GGUF or others later on. I suggest you to do this as a final step\n"\
-            "if you're planning to do multiple saves.\n"\
+            "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"
+            "to merge to GGUF or others later on. I suggest you to do this as a final step\n"
+            "if you're planning to do multiple saves.\n"
             "If you are certain, change `save_method` to `merged_4bit_forced`."
         )
     elif save_method == "merged_4bit_forced":
         save_method = "merged_4bit"
-    pass
 
     save_pretrained_settings = dict(locals())
-    for deletion in ("model", "tokenizer", "save_method", "temporary_location", "maximum_memory_usage"):
+    for deletion in (
+        "model",
+        "tokenizer",
+        "save_method",
+        "temporary_location",
+        "maximum_memory_usage",
+    ):
         del save_pretrained_settings[deletion]
-    pass
 
     # First check for a token!
     if push_to_hub:
         from huggingface_hub import whoami
+
         try:
             username = whoami(token = token)["name"]
         except:
             raise RuntimeError(
-                "Unsloth: Please supply a token!\n"\
+                "Unsloth: Please supply a token!\n"
                 "Go to https://huggingface.co/settings/tokens"
             )
-        pass
-    pass
 
-    assert(maximum_memory_usage > 0 and maximum_memory_usage <= 0.95)
+    assert maximum_memory_usage > 0 and maximum_memory_usage <= 0.95
 
     # Clean memory up first
     for _ in range(3):
         torch.cuda.empty_cache()
         gc.collect()
-    pass
 
     save_method = save_method.lower().replace(" ", "_")
-    if save_method != "lora" and save_method != "merged_16bit" and save_method != "merged_4bit":
+    if (
+        save_method != "lora"
+        and save_method != "merged_16bit"
+        and save_method != "merged_4bit"
+    ):
         raise RuntimeError(
-            "Unsloth: You must select one of 3 options when saving models:\n"\
-            '"lora"         ==> This is the fastest and easiet. Just saves LoRA modules.\n'\
-            '"merged_16bit" ==> This merges LoRA weights and saves to float16. Needed for llama.cpp / GGUF.\n'\
+            "Unsloth: You must select one of 3 options when saving models:\n"
+            '"lora"         ==> This is the fastest and easiet. Just saves LoRA modules.\n'
+            '"merged_16bit" ==> This merges LoRA weights and saves to float16. Needed for llama.cpp / GGUF.\n'
             '"merged_4bit"  ==> This merges LoRA weights and saves to 4bit. Useful for DPO / inference.'
         )
-    pass
 
     if save_method == "merged_4bit":
-
         print("Unsloth: Merging 4bit and LoRA weights to 4bit...")
         print("This might take 5 minutes...")
 
         # Counteract no LoRA adapters!
         if hasattr(model, "merge_and_unload"):
             model = model.merge_and_unload()
-        pass
         print("Done.")
-    pass
 
     if tags is not None:
-        assert(isinstance(tags, (list, tuple)))
-        tags = list(tags) + ["unsloth",]
+        assert isinstance(tags, (list, tuple))
+        tags = list(tags) + [
+            "unsloth",
+        ]
     else:
-        tags = ["unsloth",]
-    pass
+        tags = [
+            "unsloth",
+        ]
     save_pretrained_settings["tags"] = tags
 
     if ((save_method == "lora") or (save_method == "merged_4bit")) and push_to_hub:
         if token is None:
             raise RuntimeError(
-                "Unsloth: Pushing to HF requires a token. Pass `token = 'hf_....'`\n"\
+                "Unsloth: Pushing to HF requires a token. Pass `token = 'hf_....'`\n"
                 "Go to https://huggingface.co/settings/tokens."
             )
-        pass
 
         if save_method == "lora":
             print("Unsloth: Saving LoRA adapters. Please wait...")
         elif save_method == "merged_4bit":
             print("Unsloth: Saving 4bit Bitsandbytes model. Please wait...")
-        pass
 
         # Update model tag
         _ = upload_to_huggingface(
-            model, save_directory, token,
-            "finetuned", "trl", file_location = None,
-            old_username = None, private = private,
+            model,
+            save_directory,
+            token,
+            "finetuned",
+            "trl",
+            file_location = None,
+            old_username = None,
+            private = private,
         )
 
-        getattr(model, "original_push_to_hub", model.push_to_hub)\
-        (
-            repo_id            = save_directory,
-            use_temp_dir       = use_temp_dir,
-            commit_message     = commit_message,
-            private            = private,
-            token              = token,
-            max_shard_size     = max_shard_size,
-            create_pr          = create_pr,
+        getattr(model, "original_push_to_hub", model.push_to_hub)(
+            repo_id = save_directory,
+            use_temp_dir = use_temp_dir,
+            commit_message = commit_message,
+            private = private,
+            token = token,
+            max_shard_size = max_shard_size,
+            create_pr = create_pr,
             safe_serialization = safe_serialization,
-            revision           = revision,
+            revision = revision,
             commit_description = commit_description,
-            tags               = tags,
+            tags = tags,
         )
         if tokenizer is not None:
             # Set padding side to left for inference
             old_padding_side = tokenizer.padding_side
             tokenizer.padding_side = "left"
 
-            getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)\
-            (
-                repo_id            = save_directory,
-                use_temp_dir       = use_temp_dir,
-                commit_message     = commit_message,
-                private            = private,
-                token              = token,
-                max_shard_size     = max_shard_size,
-                create_pr          = create_pr,
+            getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)(
+                repo_id = save_directory,
+                use_temp_dir = use_temp_dir,
+                commit_message = commit_message,
+                private = private,
+                token = token,
+                max_shard_size = max_shard_size,
+                create_pr = create_pr,
                 safe_serialization = safe_serialization,
-                revision           = revision,
+                revision = revision,
                 commit_description = commit_description,
-                tags               = tags,
+                tags = tags,
             )
 
             # Revert back padding side
             tokenizer.padding_side = old_padding_side
-        pass
 
         if hasattr(model, "config"):
-            print(f"Saved {save_method} model to https://huggingface.co/" + save_directory)
-        pass
+            print(
+                f"Saved {save_method} model to https://huggingface.co/" + save_directory
+            )
         return save_directory, None
-    pass
 
     # Tokenizer has different saving arguments
-    tokenizer_save_settings = \
-    {
-        "save_directory"  : save_pretrained_settings["save_directory"],
-        "legacy_format"   : None,
-        "filename_prefix" : None,
-        "push_to_hub"     : save_pretrained_settings["push_to_hub"],
-        "private"         : save_pretrained_settings["private"],
-        "token"           : save_pretrained_settings["token"],
+    tokenizer_save_settings = {
+        "save_directory": save_pretrained_settings["save_directory"],
+        "legacy_format": None,
+        "filename_prefix": None,
+        "push_to_hub": save_pretrained_settings["push_to_hub"],
+        "private": save_pretrained_settings["private"],
+        "token": save_pretrained_settings["token"],
     }
 
     # Check if PEFT Model or not - if yes, 3 levels. If not 2 levels.
     from peft import PeftModelForCausalLM
+
     if isinstance(model, PeftModelForCausalLM):
         internal_model = model.model
     else:
         internal_model = model
-    pass
 
     # Cannot be converted properly!
-    if (save_method == "merged_4bit") or (save_method == "lora") or (
-        not hasattr(model, "model") or \
-        not hasattr(internal_model.model, "layers")
+    if (
+        (save_method == "merged_4bit")
+        or (save_method == "lora")
+        or (not hasattr(model, "model") or not hasattr(internal_model.model, "layers"))
     ):
         # Do general saving
         # Edit save_pretrained_settings
         # [TODO] _create_repo has errors due to **kwargs getting accepted
         # commit_description does not seem to work?
-        what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \
-            if save_pretrained_settings["push_to_hub"] is False else \
-            ("use_temp_dir", "create_pr", "revision", "tags", "commit_description",)
+        what_to_delete = (
+            (
+                "use_temp_dir",
+                "commit_message",
+                "create_pr",
+                "revision",
+                "commit_description",
+                "tags",
+            )
+            if save_pretrained_settings["push_to_hub"] is False
+            else (
+                "use_temp_dir",
+                "create_pr",
+                "revision",
+                "tags",
+                "commit_description",
+            )
+        )
         for deletion in what_to_delete:
             del save_pretrained_settings[deletion]
-        pass
         if hasattr(model, "add_model_tags"):
-            model.add_model_tags(["unsloth",])
+            model.add_model_tags(
+                [
+                    "unsloth",
+                ]
+            )
 
         # Update model tag
         if push_to_hub:
-             _ = upload_to_huggingface(
-                model, save_pretrained_settings["save_directory"], token,
-                "finetuned", "trl", file_location = None,
-                old_username = None, private = private,
+            _ = upload_to_huggingface(
+                model,
+                save_pretrained_settings["save_directory"],
+                token,
+                "finetuned",
+                "trl",
+                file_location = None,
+                old_username = None,
+                private = private,
             )
-        pass
 
         if tokenizer is not None:
             print("Unsloth: Saving tokenizer...", end = "")
@@ -444,47 +486,48 @@ def unsloth_save_model(
             print()
 
         print("Unsloth: Saving model...", end = "")
-        if save_method != "lora": print(" This might take 10 minutes for Llama-7b...", end = "")
+        if save_method != "lora":
+            print(" This might take 10 minutes for Llama-7b...", end = "")
 
         # [TODO] Is this correct?
         if save_method == "lora":
             save_pretrained_settings["selected_adapters"] = None
-        pass
 
         model.save_pretrained(**save_pretrained_settings)
 
         if push_to_hub and hasattr(model, "config"):
-            print("Saved to https://huggingface.co/" + save_pretrained_settings["save_directory"])
-        pass
+            print(
+                "Saved to https://huggingface.co/"
+                + save_pretrained_settings["save_directory"]
+            )
 
         print(" Done.")
         return save_directory, None
-    pass
 
     # If push_to_hub, we must remove the .../ part of a repo
     username = None
     if push_to_hub and "/" in save_directory:
-
         # +1 solves absolute path issues
         new_save_directory = save_directory
-        username = new_save_directory[:new_save_directory.find("/")]
-        new_save_directory = new_save_directory[new_save_directory.find("/")+1:]
+        username = new_save_directory[: new_save_directory.find("/")]
+        new_save_directory = new_save_directory[new_save_directory.find("/") + 1 :]
         if IS_KAGGLE_ENVIRONMENT:
-            new_save_directory = os.path.join(KAGGLE_TMP, new_save_directory[new_save_directory.find("/")+1:])
+            new_save_directory = os.path.join(
+                KAGGLE_TMP, new_save_directory[new_save_directory.find("/") + 1 :]
+            )
             logger.warning_once(
-                "Unsloth: You are pushing to hub in Kaggle environment.\n"\
+                "Unsloth: You are pushing to hub in Kaggle environment.\n"
                 f"To save memory, we shall move {save_directory} to {new_save_directory}"
             )
         else:
             logger.warning_once(
-                f"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\n"\
+                f"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\n"
                 f"We shall truncate {save_directory} to {new_save_directory}"
             )
 
         save_pretrained_settings["save_directory"] = new_save_directory
-        tokenizer_save_settings ["save_directory"] = new_save_directory
+        tokenizer_save_settings["save_directory"] = new_save_directory
         save_directory = new_save_directory
-    pass
 
     print("Unsloth: Merging 4bit and LoRA weights to 16bit...")
 
@@ -492,18 +535,25 @@ def unsloth_save_model(
     max_ram = psutil.virtual_memory().available
     sharded_ram_usage = 5 * 1024 * 1024 * 1024
     if type(max_shard_size) is str:
-        gb_found = re.match(r"([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE)
-        mb_found = re.match(r"([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE)
-        if   gb_found: sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024
-        elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024
+        gb_found = re.match(
+            r"([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE
+        )
+        mb_found = re.match(
+            r"([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE
+        )
+        if gb_found:
+            sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024
+        elif mb_found:
+            sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024
     elif type(max_shard_size) is int:
         sharded_ram_usage = sharded_ram_usage
-    pass
 
     # Switch to our fast saving modules if it's a slow PC!
     n_cpus = psutil.cpu_count(logical = False)
-    if n_cpus is None: n_cpus = psutil.cpu_count()
-    if n_cpus is None: n_cpus = 1
+    if n_cpus is None:
+        n_cpus = psutil.cpu_count()
+    if n_cpus is None:
+        n_cpus = 1
 
     if safe_serialization is None:
         safe_serialization = True
@@ -511,27 +561,27 @@ def unsloth_save_model(
 
     elif safe_serialization and (n_cpus <= 2):
         logger.warning_once(
-            f"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\n"\
-            f"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\n"\
+            f"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\n"
+            f"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\n"
             f"To force `safe_serialization`, set it to `None` instead.",
         )
         safe_serialization = False
         save_function = fast_save_pickle
         save_pretrained_settings["safe_serialization"] = safe_serialization
-        save_pretrained_settings["save_function"]      = save_function
-    pass
+        save_pretrained_settings["save_function"] = save_function
 
     # Only safe_serialization uses more RAM
     if safe_serialization:
         max_ram -= sharded_ram_usage
     else:
-        max_ram -= sharded_ram_usage*0.25 # Uses much less
-    pass
+        max_ram -= sharded_ram_usage * 0.25  # Uses much less
 
     max_ram = int(max(0, max_ram) * maximum_memory_usage)
-    print(f"Unsloth: Will use up to "\
-          f"{round(max_ram/1024/1024/1024, 2)} out of "\
-          f"{round(psutil.virtual_memory().total/1024/1024/1024, 2)} RAM for saving.")
+    print(
+        f"Unsloth: Will use up to "
+        f"{round(max_ram / 1024 / 1024 / 1024, 2)} out of "
+        f"{round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 2)} RAM for saving."
+    )
 
     # Move temporary_location to /tmp in Kaggle
     if IS_KAGGLE_ENVIRONMENT:
@@ -540,36 +590,41 @@ def unsloth_save_model(
     # Max directory for disk saving
     if not os.path.exists(temporary_location):
         os.makedirs(temporary_location)
-    pass
 
     # Check if Kaggle or Colab, since only 20GB of Disk space allowed.
     if IS_KAGGLE_ENVIRONMENT or IS_COLAB_ENVIRONMENT:
         # We free up 4GB of space
         logger.warning_once(
-            "Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded\n"\
+            "Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded\n"
             "model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab."
         )
         _free_cached_model(internal_model)
-    pass
 
     # HF also uses a OrderedDict
     from collections import OrderedDict
+
     state_dict = OrderedDict()
 
     torch_dtype = dtype_from_config(internal_model.config)
     if type(torch_dtype) is str:
-        if   torch_dtype ==  "float16": torch_dtype = torch.float16
-        elif torch_dtype == "bfloat16": torch_dtype = torch.bfloat16
-    pass
+        if torch_dtype == "float16":
+            torch_dtype = torch.float16
+        elif torch_dtype == "bfloat16":
+            torch_dtype = torch.bfloat16
 
     # Check modules to save float32 dtype
-    state_dict["model.embed_tokens.weight"] = internal_model.model.embed_tokens.weight.data.to(torch_dtype)
+    state_dict["model.embed_tokens.weight"] = (
+        internal_model.model.embed_tokens.weight.data.to(torch_dtype)
+    )
 
-    max_vram = int(torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage)
+    max_vram = int(
+        torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage
+    )
 
     print("Unsloth: Saving model... This might take 5 minutes ...")
 
     from tqdm import tqdm as ProgressBar
+
     for j, layer in enumerate(ProgressBar(internal_model.model.layers)):
         for item in LLAMA_WEIGHTS:
             proj = eval(f"layer.{item}")
@@ -579,7 +634,6 @@ def unsloth_save_model(
             # Bias term
             if bias is not None:
                 state_dict[f"model.layers.{j}.{item}.bias"] = bias
-            pass
 
             if (torch.cuda.memory_allocated() + W.nbytes) < max_vram:
                 # Save to GPU memory
@@ -594,70 +648,103 @@ def unsloth_save_model(
                 # Save to Disk
                 logger.warning_once("\nWe will save to Disk and not RAM now.")
                 filename = os.path.join(temporary_location, f"{name}.pt")
-                torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
+                torch.save(
+                    W,
+                    filename,
+                    pickle_module = pickle,
+                    pickle_protocol = pickle.HIGHEST_PROTOCOL,
+                )
                 # weights_only = True weirdly fails?
-                state_dict[name] = torch.load(filename, map_location = "cpu", mmap = True, weights_only = False)
-        pass
+                state_dict[name] = torch.load(
+                    filename, map_location = "cpu", mmap = True, weights_only = False
+                )
         for item in LLAMA_LAYERNORMS:
             try:
                 # Skip for Gemma 2
-                state_dict[f"model.layers.{j}.{item}.weight"] = eval(f"layer.{item}.weight.data")
+                state_dict[f"model.layers.{j}.{item}.weight"] = eval(
+                    f"layer.{item}.weight.data"
+                )
             except:
                 continue
-        pass
-    pass
 
     state_dict["model.norm.weight"] = internal_model.model.norm.weight.data
     # Check for modules_to_save float32 dtype
 
     # Check for tied weights
-    if internal_model.model.embed_tokens.weight.data_ptr() != internal_model.lm_head.weight.data_ptr():
-        state_dict["lm_head.weight"] = internal_model.lm_head.weight.data.to(torch_dtype)
-    pass
+    if (
+        internal_model.model.embed_tokens.weight.data_ptr()
+        != internal_model.lm_head.weight.data_ptr()
+    ):
+        state_dict["lm_head.weight"] = internal_model.lm_head.weight.data.to(
+            torch_dtype
+        )
 
     # All tensors MUST be type torch.Tensor and not torch.nn.parameter.Parameter
     for key, value in state_dict.items():
-        if hasattr(value, "data"): state_dict[key] = value = value.data
+        if hasattr(value, "data"):
+            state_dict[key] = value = value.data
         if type(value) is not torch.Tensor:
             logger.warning_once(f"Unsloth: {key} is not a Tensor but a {type(value)}.")
-        pass
-    pass
 
     # Edit save_pretrained_settings
     # [TODO] _create_repo has errors due to **kwargs getting accepted
     save_pretrained_settings["state_dict"] = state_dict
 
     # commit_description does not seem to work?
-    what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \
-        if not push_to_hub else \
-        ("use_temp_dir", "create_pr", "revision", "tags", "commit_description",)
+    what_to_delete = (
+        (
+            "use_temp_dir",
+            "commit_message",
+            "create_pr",
+            "revision",
+            "commit_description",
+            "tags",
+        )
+        if not push_to_hub
+        else (
+            "use_temp_dir",
+            "create_pr",
+            "revision",
+            "tags",
+            "commit_description",
+        )
+    )
     for deletion in what_to_delete:
         del save_pretrained_settings[deletion]
-    pass
     if hasattr(model, "add_model_tags"):
-        model.add_model_tags(["unsloth",])
+        model.add_model_tags(
+            [
+                "unsloth",
+            ]
+        )
 
     # Update model tag
     if push_to_hub:
         _ = upload_to_huggingface(
-            model, save_pretrained_settings["save_directory"], token,
-            "finetuned", "trl", file_location = None,
-            old_username = username, private = private,
+            model,
+            save_pretrained_settings["save_directory"],
+            token,
+            "finetuned",
+            "trl",
+            file_location = None,
+            old_username = username,
+            private = private,
         )
-    pass
 
     # First check if we're pushing to an organization!
     save_directory = save_pretrained_settings["save_directory"]
 
     if save_pretrained_settings["push_to_hub"]:
-        new_save_directory, new_username = _determine_username(save_directory, username, token)
+        new_save_directory, new_username = _determine_username(
+            save_directory, username, token
+        )
 
         if token is not None:
             from huggingface_hub import whoami
+
             actual_username = whoami(token = token)["name"]
         else:
             actual_username = username
-    pass
 
     # Check if pushing to an organization
     if save_pretrained_settings["push_to_hub"] and (username != actual_username):
@@ -665,7 +752,6 @@ def unsloth_save_model(
         # We upload everything at the end!
         tokenizer_save_settings["push_to_hub"] = False
         tokenizer_save_settings["save_directory"] = new_save_directory
-    pass
 
     # Save tokenizer
     if tokenizer is not None:
@@ -683,7 +769,6 @@ def unsloth_save_model(
         print(" Done.")
     else:
         print()
-    pass
 
     # Since merged, edit quantization_config
     old_config = model.config
@@ -722,12 +807,11 @@ def unsloth_save_model(
             path_in_repo = ".",
             repo_id = new_save_directory,
             repo_type = "model",
-            commit_message  = "(Trained with Unsloth)",
+            commit_message = "(Trained with Unsloth)",
             ignore_patterns = "*.md",
         )
     else:
         internal_model.save_pretrained(**save_pretrained_settings)
-    pass
 
     # Revert config back
     original_model = model
@@ -738,8 +822,9 @@ def unsloth_save_model(
     print("Done.")
 
     if push_to_hub and hasattr(model, "config"):
-        print(f"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/').split('/')[-1]}")
-    pass
+        print(
+            f"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/').split('/')[-1]}"
+        )
 
     save_pretrained_settings["state_dict"] = None
 
@@ -748,8 +833,6 @@ def unsloth_save_model(
         if j % 10 == 0:
             torch.cuda.empty_cache()
             gc.collect()
-        pass
-    pass
     state_dict = None
     del state_dict
     torch.cuda.empty_cache()
@@ -757,20 +840,26 @@ def unsloth_save_model(
 
     # Remove temporary location
     import shutil
+
     shutil.rmtree(temporary_location, ignore_errors = True)
 
     for _ in range(3):
         torch.cuda.empty_cache()
         gc.collect()
     return save_directory, username
-pass
 
 
 def install_llama_cpp_clone_non_blocking():
-    full_command = ["git", "clone", "--recursive", "https://github.com/ggerganov/llama.cpp"]
-    run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
+    full_command = [
+        "git",
+        "clone",
+        "--recursive",
+        "https://github.com/ggerganov/llama.cpp",
+    ]
+    run_installer = subprocess.Popen(
+        full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT
+    )
     return run_installer
-pass
 
 
 def install_llama_cpp_make_non_blocking():
@@ -782,71 +871,89 @@ def install_llama_cpp_make_non_blocking():
     IS_CMAKE = False
     if check == 0:
         # Uses old MAKE
-        n_jobs = max(int(psutil.cpu_count()*1.5), 1)
-        full_command = ["make", "all", "-j"+str(n_jobs), "-C", "llama.cpp"]
+        n_jobs = max(int(psutil.cpu_count() * 1.5), 1)
+        full_command = ["make", "all", "-j" + str(n_jobs), "-C", "llama.cpp"]
         IS_CMAKE = False
     else:
         # Uses new CMAKE
-        n_jobs = max(int(psutil.cpu_count()), 1) # Use less CPUs since 1.5x faster
-        check = os.system("cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON")
+        n_jobs = max(int(psutil.cpu_count()), 1)  # Use less CPUs since 1.5x faster
+        check = os.system(
+            "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON"
+        )
         if check != 0:
-            raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp using os.system(...) with error {check}. Please report this ASAP!")
-        pass
+            raise RuntimeError(
+                f"*** Unsloth: Failed compiling llama.cpp using os.system(...) with error {check}. Please report this ASAP!"
+            )
         # f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
         full_command = [
-            "cmake", "--build", "llama.cpp/build",
-            "--config", "Release",
-            "-j"+str(n_jobs),
+            "cmake",
+            "--build",
+            "llama.cpp/build",
+            "--config",
+            "Release",
+            "-j" + str(n_jobs),
             "--clean-first",
             "--target",
         ] + LLAMA_CPP_TARGETS
         IS_CMAKE = True
-    pass
     # https://github.com/ggerganov/llama.cpp/issues/7062
     # Weirdly GPU conversion for GGUF breaks??
     # run_installer = subprocess.Popen(full_command, env = env, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
-    run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
+    run_installer = subprocess.Popen(
+        full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT
+    )
     return run_installer, IS_CMAKE
-pass
 
 
 def install_python_non_blocking(packages = []):
     full_command = ["pip", "install"] + packages
-    run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
+    run_installer = subprocess.Popen(
+        full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT
+    )
     return run_installer
-pass
 
 
 def try_execute(commands, force_complete = False):
     for command in commands:
-        with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
+        with subprocess.Popen(
+            command,
+            shell = True,
+            stdout = subprocess.PIPE,
+            stderr = subprocess.STDOUT,
+            bufsize = 1,
+        ) as sp:
             for line in sp.stdout:
                 line = line.decode("utf-8", errors = "replace")
                 if "undefined reference" in line:
-                    raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!")
+                    raise RuntimeError(
+                        f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!"
+                    )
                 elif "deprecated" in line:
                     return "CMAKE"
                 elif "Unknown argument" in line:
-                    raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!")
+                    raise RuntimeError(
+                        f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!"
+                    )
                 elif "***" in line:
-                    raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!")
+                    raise RuntimeError(
+                        f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!"
+                    )
                 print(line, flush = True, end = "")
-            pass
             if force_complete and sp.returncode is not None and sp.returncode != 0:
                 raise subprocess.CalledProcessError(sp.returncode, sp.args)
-        pass
-    pass
     return None
-pass
 
 
 def install_llama_cpp_old(version = -10):
     # Download the 10th latest release since the latest might be broken!
     # FALLBACK mechanism
-    releases = subprocess.check_output(["git", "ls-remote", "--tags", "https://github.com/ggerganov/llama.cpp.git"])
+    releases = subprocess.check_output(
+        ["git", "ls-remote", "--tags", "https://github.com/ggerganov/llama.cpp.git"]
+    )
     releases = releases.decode("utf-8").replace("\t", " ").split("\n")
     for i, x in enumerate(releases):
-        if "refs/tags/b" not in x: break
+        if "refs/tags/b" not in x:
+            break
     releases = releases[:i]
     latest = releases[-1]
     version = releases[version].split(" ")[0]
@@ -854,17 +961,20 @@ def install_llama_cpp_old(version = -10):
     # Check if the llama.cpp exists
     if os.path.exists("llama.cpp"):
         print(
-            "**[WARNING]** You have a llama.cpp directory which is broken.\n"\
-            "Unsloth will DELETE the broken directory and install a new one.\n"\
+            "**[WARNING]** You have a llama.cpp directory which is broken.\n"
+            "Unsloth will DELETE the broken directory and install a new one.\n"
             "Press CTRL + C / cancel this if this is wrong. We shall wait 30 seconds.\n"
         )
         import time
+
         for i in range(30):
-            print(f"**[WARNING]** Deleting llama.cpp directory... {30-i} seconds left.")
+            print(
+                f"**[WARNING]** Deleting llama.cpp directory... {30 - i} seconds left."
+            )
             time.sleep(1)
         import shutil
+
         shutil.rmtree("llama.cpp", ignore_errors = True)
-    pass
 
     # Clone a specific commit
     # Also don't use the GPU!
@@ -877,35 +987,32 @@ def install_llama_cpp_old(version = -10):
     # Try using MAKE
     commands = [
         "make clean -C llama.cpp",
-        f"make all -j{psutil.cpu_count()*2} -C llama.cpp",
+        f"make all -j{psutil.cpu_count() * 2} -C llama.cpp",
     ]
     if try_execute(commands) == "CMAKE":
         # Instead use CMAKE
         commands = [
             "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON",
-            f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
+            f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count() * 2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
             "cp llama.cpp/build/bin/llama-* llama.cpp",
             "rm -rf llama.cpp/build",
         ]
         try_execute(commands)
-    pass
 
     # Check if successful
     if not (
-        os.path.exists("llama.cpp/llama-quantize.exe") or
-        os.path.exists("llama.cpp/llama-quantize") or
-        os.path.exists("llama.cpp/quantize.exe") or
-        os.path.exists("llama.cpp/quantize") or
-        os.path.exists("llama.cpp/build/bin/llama-quantize") or
-        os.path.exists("llama.cpp/build/bin/quantize")
+        os.path.exists("llama.cpp/llama-quantize.exe")
+        or os.path.exists("llama.cpp/llama-quantize")
+        or os.path.exists("llama.cpp/quantize.exe")
+        or os.path.exists("llama.cpp/quantize")
+        or os.path.exists("llama.cpp/build/bin/llama-quantize")
+        or os.path.exists("llama.cpp/build/bin/quantize")
     ):
         raise RuntimeError(
-            "Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"\
-            "We've also double checked the building directory under 'llama.cpp/build/bin/'.\n"\
+            "Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"
+            "We've also double checked the building directory under 'llama.cpp/build/bin/'.\n"
             "But we expect this file to exist! Check if the file exists under llama.cpp and investigate the building process of llama.cpp (make/cmake)!"
         )
-    pass
-pass
 
 
 def install_llama_cpp_blocking(use_cuda = False):
@@ -917,7 +1024,8 @@ def install_llama_cpp_blocking(use_cuda = False):
         "git clone --recursive https://github.com/ggerganov/llama.cpp",
         "pip install gguf protobuf",
     ]
-    if os.path.exists("llama.cpp"): return
+    if os.path.exists("llama.cpp"):
+        return
     try_execute(commands)
 
     commands = [
@@ -925,19 +1033,17 @@ def install_llama_cpp_blocking(use_cuda = False):
         # https://github.com/ggerganov/llama.cpp/issues/7062
         # Weirdly GPU conversion for GGUF breaks??
         # f"{use_cuda} make all -j{psutil.cpu_count()*2} -C llama.cpp",
-        f"make all -j{psutil.cpu_count()*2} -C llama.cpp",
+        f"make all -j{psutil.cpu_count() * 2} -C llama.cpp",
     ]
     if try_execute(commands) == "CMAKE":
         # Instead use CMAKE
         commands = [
             "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON",
-            f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
+            f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count() * 2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
             "cp llama.cpp/build/bin/llama-* llama.cpp",
             "rm -rf llama.cpp/build",
         ]
         try_execute(commands)
-    pass
-pass
 
 
 def get_executable(executables):
@@ -948,23 +1054,21 @@ def get_executable(executables):
         for executable in executables:
             path = os.path.join(directory, executable)
             # Check if the executable exists and is executable
-            if os.path.exists(path) and os.access(path, os.X_OK): return path
-        pass
-    pass
+            if os.path.exists(path) and os.access(path, os.X_OK):
+                return path
     return None
-pass
 
 
 def save_to_gguf(
-    model_name           : str,
-    model_type           : str,
-    model_dtype          : str,
-    is_sentencepiece     : bool = False,
-    model_directory      : str = "unsloth_finetuned_model",
-    quantization_method  = "fast_quantized", # Can be a list of options! ["q4_k_m", "q8_0", "q5_k_m"]
-    first_conversion     : str = None,
-    is_vlm               : bool = False,
-    is_gpt_oss           : bool = False,
+    model_name: str,
+    model_type: str,
+    model_dtype: str,
+    is_sentencepiece: bool = False,
+    model_directory: str = "unsloth_finetuned_model",
+    quantization_method = "fast_quantized",  # Can be a list of options! ["q4_k_m", "q8_0", "q5_k_m"]
+    first_conversion: str = None,
+    is_vlm: bool = False,
+    is_gpt_oss: bool = False,
 ):
     """
     Orchestrates the complete GGUF conversion process.
@@ -977,44 +1081,53 @@ def save_to_gguf(
         print_output = False
 
     # Validate model dtype
-    assert(model_dtype == "float16" or model_dtype == "bfloat16")
+    assert model_dtype == "float16" or model_dtype == "bfloat16"
     model_dtype = "f16" if model_dtype == "float16" else "bf16"
 
     # Convert quantization_method to list
-    if   isinstance(quantization_method, list):  pass
-    elif isinstance(quantization_method, str):   quantization_method = [ quantization_method, ]
-    elif isinstance(quantization_method, tuple): quantization_method = list(quantization_method)
+    if isinstance(quantization_method, list):
+        pass
+    elif isinstance(quantization_method, str):
+        quantization_method = [
+            quantization_method,
+        ]
+    elif isinstance(quantization_method, tuple):
+        quantization_method = list(quantization_method)
     else:
-        raise TypeError("Unsloth: quantization_method can only be a string or a list of strings")
-    pass
+        raise TypeError(
+            "Unsloth: quantization_method can only be a string or a list of strings"
+        )
 
     # Check if bfloat16 is supported
     if model_dtype == "bf16" and not torch.cuda.is_bf16_supported():
         logger.warning(
-            "Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\n"\
+            "Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\n"
             "We shall switch instead to f16."
         )
         model_dtype = "f16"
-    pass
 
     # Check first_conversion as well
     if first_conversion is None:
         first_conversion = model_dtype
-    pass
 
     # Check I quants
     for quant_method in quantization_method:
         if quant_method.startswith("iq2"):
-            raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!")
-    pass
+            raise RuntimeError(
+                "Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!"
+            )
 
     # Map quant methods
     new_quantization_methods = []
     for quant_method in quantization_method:
-        if   quant_method == "not_quantized":  quant_method = model_dtype
-        elif quant_method == "fast_quantized": quant_method = "q8_0"
-        elif quant_method == "quantized":      quant_method = "q4_k_m"
-        elif quant_method is None:             quant_method = "q8_0"
+        if quant_method == "not_quantized":
+            quant_method = model_dtype
+        elif quant_method == "fast_quantized":
+            quant_method = "q8_0"
+        elif quant_method == "quantized":
+            quant_method = "q4_k_m"
+        elif quant_method is None:
+            quant_method = "q8_0"
 
         # Check if wrong method
         if quant_method not in ALLOWED_QUANTS.keys():
@@ -1022,10 +1135,8 @@ def save_to_gguf(
             for key, value in ALLOWED_QUANTS.items():
                 error += f"[{key}] => {value}\n"
             raise RuntimeError(error)
-        pass
 
         new_quantization_methods.append(quant_method)
-    pass
     quantization_method = new_quantization_methods
 
     # Determine optimal first_conversion
@@ -1044,15 +1155,22 @@ def save_to_gguf(
                 # that can be requantized to all requested formats
                 strength = 0
                 for quant_method in quantization_method:
-                    if   quant_method == "f32":  strength = max(strength, 3)
-                    elif quant_method == "f16":  strength = max(strength, 2)
-                    elif quant_method == "bf16": strength = max(strength, 1)
+                    if quant_method == "f32":
+                        strength = max(strength, 3)
+                    elif quant_method == "f16":
+                        strength = max(strength, 2)
+                    elif quant_method == "bf16":
+                        strength = max(strength, 1)
                     # Note: we don't set strength for q8_0 here since we handle it above
 
-                if   strength >= 3: first_conversion = "f32"
-                elif strength >= 2: first_conversion = "f16"
-                elif strength >= 1: first_conversion = "bf16"
-                else: first_conversion = "bf16" # requantizing from q8_0 disallowed in new llama.cpp default to bf16.
+                if strength >= 3:
+                    first_conversion = "f32"
+                elif strength >= 2:
+                    first_conversion = "f16"
+                elif strength >= 1:
+                    first_conversion = "bf16"
+                else:
+                    first_conversion = "bf16"  # requantizing from q8_0 disallowed in new llama.cpp default to bf16.
 
     # Check bfloat16 support again for first_conversion
     if first_conversion == "bf16" and not torch.cuda.is_bf16_supported():
@@ -1061,12 +1179,13 @@ def save_to_gguf(
 
     first_conversion_dtype = "" if first_conversion == "None" else first_conversion
     # Print conversion info
-    print_info = \
-        f"==((====))==  Unsloth: Conversion from HF to GGUF information\n"\
-        f"   {chr(92)}{chr(92)}   /|    [0] Installing llama.cpp might take 3 minutes.\n"\
-        f"O^O/ {chr(92)}_/ {chr(92)}    [1] Converting HF to GGUF {first_conversion_dtype} might take 3 minutes.\n"\
-        f"{chr(92)}        /    [2] Converting GGUF {first_conversion_dtype} to {quantization_method} might take 10 minutes each.\n"\
+    print_info = (
+        f"==((====))==  Unsloth: Conversion from HF to GGUF information\n"
+        f"   {chr(92)}{chr(92)}   /|    [0] Installing llama.cpp might take 3 minutes.\n"
+        f"O^O/ {chr(92)}_/ {chr(92)}    [1] Converting HF to GGUF {first_conversion_dtype} might take 3 minutes.\n"
+        f"{chr(92)}        /    [2] Converting GGUF {first_conversion_dtype} to {quantization_method} might take 10 minutes each.\n"
         f' "-____-"     In total, you will have to wait at least 16 minutes.\n'
+    )
     print(print_info)
 
     # Step 1: Ensure llama.cpp is installed
@@ -1078,36 +1197,39 @@ def save_to_gguf(
         if IS_KAGGLE_ENVIRONMENT:
             # Kaggle: no CUDA support due to environment limitations
             quantizer_location, converter_location = install_llama_cpp(
-                gpu_support=False,
-                print_output=print_output
+                gpu_support = False, print_output = print_output
             )
         else:
             quantizer_location, converter_location = install_llama_cpp(
-                gpu_support=False,  # GGUF conversion doesn't need CUDA
-                print_output=print_output
+                gpu_support = False,  # GGUF conversion doesn't need CUDA
+                print_output = print_output,
             )
 
     # Step 2: Download and patch converter script
     print("Unsloth: Preparing converter script...")
     with use_local_gguf():
-        converter_path, supported_text_archs, supported_vision_archs = _download_convert_hf_to_gguf()
+        converter_path, supported_text_archs, supported_vision_archs = (
+            _download_convert_hf_to_gguf()
+        )
 
         # Step 3: Initial GGUF conversion
-        print(f"Unsloth: [1] Converting model into {first_conversion_dtype} GGUF format.")
+        print(
+            f"Unsloth: [1] Converting model into {first_conversion_dtype} GGUF format."
+        )
         print(f"This might take 3 minutes...")
 
         initial_files, is_vlm_update = convert_to_gguf(
-            model_name=model_name,
-            input_folder=model_directory,
+            model_name = model_name,
+            input_folder = model_directory,
             model_dtype = model_dtype,
-            quantization_type=first_conversion,
-            converter_location=converter_path,
-            supported_text_archs=supported_text_archs,
-            supported_vision_archs=supported_vision_archs,
-            is_vlm=is_vlm,
-            is_gpt_oss=is_gpt_oss,
-            max_shard_size="50GB",
-            print_output=print_output,
+            quantization_type = first_conversion,
+            converter_location = converter_path,
+            supported_text_archs = supported_text_archs,
+            supported_vision_archs = supported_vision_archs,
+            is_vlm = is_vlm,
+            is_gpt_oss = is_gpt_oss,
+            max_shard_size = "50GB",
+            print_output = print_output,
         )
     # update is_vlm switch
     is_vlm = is_vlm_update
@@ -1134,7 +1256,8 @@ def save_to_gguf(
 
     # Get CPU count for quantization
     n_cpus = psutil.cpu_count()
-    if n_cpus is None: n_cpus = 1
+    if n_cpus is None:
+        n_cpus = 1
     n_cpus *= 2
 
     if not is_gpt_oss:
@@ -1142,48 +1265,46 @@ def save_to_gguf(
         quants_created = False
         for quant_method in quantization_method:
             if quant_method != first_conversion:
-                print(f"Unsloth: [2] Converting GGUF {first_conversion_dtype} into {quant_method}. This might take 10 minutes...")
+                print(
+                    f"Unsloth: [2] Converting GGUF {first_conversion_dtype} into {quant_method}. This might take 10 minutes..."
+                )
                 output_location = f"{model_name}.{quant_method.upper()}.gguf"
 
                 try:
                     # Use the quantize_gguf function we created
                     quantized_file = quantize_gguf(
-                        input_gguf=base_gguf,
-                        output_gguf=output_location,
-                        quant_type=quant_method,
-                        quantizer_location=quantizer_location,
-                        print_output=print_output
+                        input_gguf = base_gguf,
+                        output_gguf = output_location,
+                        quant_type = quant_method,
+                        quantizer_location = quantizer_location,
+                        print_output = print_output,
                     )
                     all_saved_locations.append(quantized_file)
                     quants_created = True
                 except Exception as e:
                     if IS_KAGGLE_ENVIRONMENT:
                         raise RuntimeError(
-                            f"Unsloth: Quantization failed for {output_location}\n"\
-                            "You are in a Kaggle environment, which might be the reason this is failing.\n"\
-                            "Kaggle only provides 20GB of disk space in the working directory.\n"\
-                            "Merging to 16bit for 7b models use 16GB of space.\n"\
-                            "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\
-                            "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\
-                            "You can try saving it to the `/tmp` directory for larger disk space.\n"\
-                            "I suggest you to save the 16bit model first, then use manual llama.cpp conversion.\n"\
+                            f"Unsloth: Quantization failed for {output_location}\n"
+                            "You are in a Kaggle environment, which might be the reason this is failing.\n"
+                            "Kaggle only provides 20GB of disk space in the working directory.\n"
+                            "Merging to 16bit for 7b models use 16GB of space.\n"
+                            "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"
+                            "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"
+                            "You can try saving it to the `/tmp` directory for larger disk space.\n"
+                            "I suggest you to save the 16bit model first, then use manual llama.cpp conversion.\n"
                             "Error: {e}"
                         )
                     else:
                         raise RuntimeError(
-                            f"Unsloth: Quantization failed for {output_location}\n"\
-                            "You might have to compile llama.cpp yourself, then run this again.\n"\
-                            "You do not need to close this Python program. Run the following commands in a new terminal:\n"\
-                            "You must run this in the same folder as you're saving your model.\n"\
-                            "git clone --recursive https://github.com/ggerganov/llama.cpp\n"\
-                            "cd llama.cpp && make clean && make all -j\n"\
-                            "Once that's done, redo the quantization.\n"\
+                            f"Unsloth: Quantization failed for {output_location}\n"
+                            "You might have to compile llama.cpp yourself, then run this again.\n"
+                            "You do not need to close this Python program. Run the following commands in a new terminal:\n"
+                            "You must run this in the same folder as you're saving your model.\n"
+                            "git clone --recursive https://github.com/ggerganov/llama.cpp\n"
+                            "cd llama.cpp && make clean && make all -j\n"
+                            "Once that's done, redo the quantization.\n"
                             "Error: {e}"
                         )
-                    pass
-                pass
-            pass
-        pass
         print("Unsloth: Model files cleanup...")
         if quants_created:
             all_saved_locations.remove(base_gguf)
@@ -1193,7 +1314,6 @@ def save_to_gguf(
             all_saved_locations.reverse()
     else:
         print("Unsloth: GPT-OSS model - skipping additional quantizations")
-    pass
 
     if is_gpt_oss:
         want_full_precision = True
@@ -1204,42 +1324,40 @@ def save_to_gguf(
     print(f"Generated files: {all_saved_locations}")
 
     return all_saved_locations, want_full_precision, is_vlm
-pass
 
 
 def unsloth_save_pretrained_merged(
     self,
-    save_directory       : Union[str, os.PathLike],
-    tokenizer            = None,
-    save_method          : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
-    push_to_hub          : bool = False,
-    token                : Optional[Union[str, bool]] = None,
-    is_main_process      : bool = True,
-    state_dict           : Optional[dict] = None,
-    save_function        : Callable = torch.save,
-    max_shard_size       : Union[int, str] = "5GB",
-    safe_serialization   : bool = True,
-    variant              : Optional[str] = None,
-    save_peft_format     : bool = True,
-    tags                 : List[str] = None,
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.75,
+    save_directory: Union[str, os.PathLike],
+    tokenizer = None,
+    save_method: str = "merged_16bit",  # ["lora", "merged_16bit", "merged_4bit"]
+    push_to_hub: bool = False,
+    token: Optional[Union[str, bool]] = None,
+    is_main_process: bool = True,
+    state_dict: Optional[dict] = None,
+    save_function: Callable = torch.save,
+    max_shard_size: Union[int, str] = "5GB",
+    safe_serialization: bool = True,
+    variant: Optional[str] = None,
+    save_peft_format: bool = True,
+    tags: List[str] = None,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.75,
 ):
     """
-        Same as .save_pretrained(...) except 4bit weights are auto
-        converted to float16 with as few overhead as possible.
+    Same as .save_pretrained(...) except 4bit weights are auto
+    converted to float16 with as few overhead as possible.
 
-        Choose for `save_method` to be either:
-        1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
-        2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
-        3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+    Choose for `save_method` to be either:
+    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+    3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
     """
     if tokenizer is None:
         logger.warning_once(
-            "Unsloth: You're not saving a tokenizer as well?\n"\
+            "Unsloth: You're not saving a tokenizer as well?\n"
             "You can do it separately via `tokenizer.save_pretrained(...)`"
         )
-    pass
 
     arguments = dict(locals())
     arguments["model"] = self
@@ -1247,57 +1365,53 @@ def unsloth_save_pretrained_merged(
     unsloth_save_model(**arguments)
     for _ in range(3):
         gc.collect()
-pass
 
 
 def unsloth_push_to_hub_merged(
     self,
-    repo_id              : str,
-    tokenizer            = None,
-    save_method          : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
-    use_temp_dir         : Optional[bool] = None,
-    commit_message       : Optional[str] = "Trained with Unsloth",
-    private              : Optional[bool] = None,
-    token                : Union[bool, str, None] = None,
-    max_shard_size       : Union[int, str, None] = "5GB",
-    create_pr            : bool = False,
-    safe_serialization   : bool = True,
-    revision             : str = None,
-    commit_description   : str = "Upload model trained with Unsloth 2x faster",
-    tags                 : Optional[List[str]] = None,
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.75,
+    repo_id: str,
+    tokenizer = None,
+    save_method: str = "merged_16bit",  # ["lora", "merged_16bit", "merged_4bit"]
+    use_temp_dir: Optional[bool] = None,
+    commit_message: Optional[str] = "Trained with Unsloth",
+    private: Optional[bool] = None,
+    token: Union[bool, str, None] = None,
+    max_shard_size: Union[int, str, None] = "5GB",
+    create_pr: bool = False,
+    safe_serialization: bool = True,
+    revision: str = None,
+    commit_description: str = "Upload model trained with Unsloth 2x faster",
+    tags: Optional[List[str]] = None,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.75,
 ):
     """
-        Same as .push_to_hub(...) except 4bit weights are auto
-        converted to float16 with as few overhead as possible.
+    Same as .push_to_hub(...) except 4bit weights are auto
+    converted to float16 with as few overhead as possible.
 
-        Choose for `save_method` to be either:
-        1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
-        2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
-        3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+    Choose for `save_method` to be either:
+    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+    3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
     """
     if tokenizer is None:
         logger.warning_once(
-            "Unsloth: You're not saving a tokenizer as well?\n"\
+            "Unsloth: You're not saving a tokenizer as well?\n"
             "You can do it separately via `tokenizer.push_to_hub(...)`"
         )
-    pass
 
     arguments = dict(locals())
-    arguments["model"]          = self
+    arguments["model"] = self
     arguments["save_directory"] = repo_id
-    arguments["push_to_hub"]    = True
+    arguments["push_to_hub"] = True
     del arguments["self"]
     del arguments["repo_id"]
     unsloth_save_model(**arguments)
     for _ in range(3):
         gc.collect()
-pass
 
 
-MODEL_CARD = \
-"""---
+MODEL_CARD = """---
 base_model: {base_model}
 tags:
 - text-generation-inference
@@ -1327,19 +1441,19 @@ def _determine_username(save_directory, old_username, token):
     save_directory = save_directory.lstrip("./")
     if "/" not in save_directory:
         from huggingface_hub import whoami
+
         try:
             username = whoami(token = token)["name"]
             if type(old_username) is str and username != old_username:
                 username = old_username
-            pass
             save_directory = f"{username}/{save_directory}"
         except:
-            raise RuntimeError(f"Unsloth: {save_directory} is not a Huggingface directory.")
+            raise RuntimeError(
+                f"Unsloth: {save_directory} is not a Huggingface directory."
+            )
     else:
         username = save_directory.split("/")[0]
-    pass
     return save_directory, username
-pass
 
 
 def create_huggingface_repo(
@@ -1348,29 +1462,30 @@ def create_huggingface_repo(
     token = None,
     private = False,
 ):
-    if token is None :
+    if token is None:
         token = get_token()
-    pass
     save_directory, username = _determine_username(save_directory, "", token)
 
     from huggingface_hub import create_repo
+
     try:
         create_repo(
-            repo_id   = save_directory,
-            token     = token,
+            repo_id = save_directory,
+            token = token,
             repo_type = "model",
-            exist_ok  = False,
-            private   = private,
+            exist_ok = False,
+            private = private,
         )
 
         # Create model card
         from huggingface_hub import ModelCard
+
         content = MODEL_CARD.format(
-            username   = username,
+            username = username,
             base_model = model.config._name_or_path,
             model_type = model.config.model_type,
-            method     = "",
-            extra      = "unsloth",
+            method = "",
+            extra = "unsloth",
         )
         card = ModelCard(content)
         card.push_to_hub(save_directory, token = token)
@@ -1378,7 +1493,6 @@ def create_huggingface_repo(
         pass
     hf_api = HfApi(token = token)
     return save_directory, hf_api
-pass
 
 
 def upload_to_huggingface(
@@ -1395,23 +1509,25 @@ def upload_to_huggingface(
     save_directory, username = _determine_username(save_directory, old_username, token)
 
     from huggingface_hub import create_repo
+
     try:
         create_repo(
-            repo_id   = save_directory,
-            token     = token,
+            repo_id = save_directory,
+            token = token,
             repo_type = "model",
-            exist_ok  = False,
-            private   = private,
+            exist_ok = False,
+            private = private,
         )
 
         # Create model card
         from huggingface_hub import ModelCard
+
         content = MODEL_CARD.format(
-            username   = username,
+            username = username,
             base_model = model.config._name_or_path,
             model_type = model.config.model_type,
-            method     = "",
-            extra      = extra,
+            method = "",
+            extra = extra,
         )
         card = ModelCard(content)
         card.push_to_hub(save_directory, token = token)
@@ -1423,53 +1539,51 @@ def upload_to_huggingface(
         hf_api = HfApi(token = token)
 
         if "/" in file_location:
-            uploaded_location = file_location[file_location.rfind("/")+1:]
+            uploaded_location = file_location[file_location.rfind("/") + 1 :]
         else:
             uploaded_location = file_location
-        pass
 
         # find ftevent file from tensorboard and upload it
         import glob
+
         ftevent_files = glob.glob("*out.tfevents*", recursive = True)
         if len(ftevent_files) > 0:
-            print("Unsloth: Uploading tensorboard files... Please wait...", file_location + "*out.tfevents*")
+            print(
+                "Unsloth: Uploading tensorboard files... Please wait...",
+                file_location + "*out.tfevents*",
+            )
             for ftevent_file in ftevent_files:
                 hf_api.upload_file(
                     path_or_fileobj = ftevent_file,
-                    path_in_repo    = ftevent_file.replace(file_location, ""),
-                    repo_id         = save_directory,
-                    repo_type       = "model",
-                    commit_message  = "(Trained with Unsloth)",
+                    path_in_repo = ftevent_file.replace(file_location, ""),
+                    repo_id = save_directory,
+                    repo_type = "model",
+                    commit_message = "(Trained with Unsloth)",
                 )
-            pass
-        pass
 
         hf_api.upload_file(
             path_or_fileobj = file_location,
-            path_in_repo    = uploaded_location,
-            repo_id         = save_directory,
-            repo_type       = "model",
-            commit_message  = "(Trained with Unsloth)",
+            path_in_repo = uploaded_location,
+            repo_id = save_directory,
+            repo_type = "model",
+            commit_message = "(Trained with Unsloth)",
         )
 
         # We also upload a config.json file
         if create_config:
             import json
+
             with open("_temporary_unsloth_config.json", "w", encoding = "utf-8") as file:
-                json.dump({"model_type" : model.config.model_type}, file, indent = 4)
-            pass
+                json.dump({"model_type": model.config.model_type}, file, indent = 4)
             hf_api.upload_file(
                 path_or_fileobj = "_temporary_unsloth_config.json",
-                path_in_repo    = "config.json",
-                repo_id         = save_directory,
-                repo_type       = "model",
-                commit_message  = "(Trained with Unsloth)",
+                path_in_repo = "config.json",
+                repo_id = save_directory,
+                repo_type = "model",
+                commit_message = "(Trained with Unsloth)",
             )
             os.remove("_temporary_unsloth_config.json")
-        pass
-    pass
     return username
-pass
 
 
 def fix_tokenizer_bos_token(tokenizer):
@@ -1477,95 +1591,97 @@ def fix_tokenizer_bos_token(tokenizer):
     fix_bos_token = False
     chat_template = getattr(tokenizer, "chat_template", None)
 
-    if (tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None)):
-        if chat_template is not None and \
-            (
-                tokenizer.bos_token in chat_template or \
-                "{bos_token}" in chat_template.replace(" ", "") or \
-                "{bos_token+" in chat_template.replace(" ", "")
-            ):
-
+    if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
+        if chat_template is not None and (
+            tokenizer.bos_token in chat_template
+            or "{bos_token}" in chat_template.replace(" ", "")
+            or "{bos_token+" in chat_template.replace(" ", "")
+        ):
             fix_bos_token = True
             logger.warning(
-                "Unsloth: ##### The current model auto adds a BOS token.\n"\
+                "Unsloth: ##### The current model auto adds a BOS token.\n"
                 "Unsloth: ##### Your chat template has a BOS token. We shall remove it temporarily."
             )
 
             # Remove {{bos_token}}
-            new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template)
+            new_chat_template = re.sub(
+                r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template
+            )
             # Remove {{bos_token +
-            new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}", "", new_chat_template)
+            new_chat_template = re.sub(
+                r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}",
+                "",
+                new_chat_template,
+            )
 
             tokenizer.chat_template = new_chat_template
 
-        pass
-    pass
     return fix_bos_token, chat_template
-pass
 
 
 def create_ollama_modelfile(tokenizer, base_model_name, model_location):
     """
-        Creates an Ollama Modelfile.
-        Use ollama.create(model = "new_ollama_model", modelfile = modelfile)
+    Creates an Ollama Modelfile.
+    Use ollama.create(model = "new_ollama_model", modelfile = modelfile)
     """
     ollama_template_name = MODEL_TO_OLLAMA_TEMPLATE_MAPPER.get(base_model_name)
     if not ollama_template_name:
-        print(f"Unsloth: No Ollama template mapping found for model '{base_model_name}'. Skipping Ollama Modelfile")
+        print(
+            f"Unsloth: No Ollama template mapping found for model '{base_model_name}'. Skipping Ollama Modelfile"
+        )
         return None
     ollama_modelfile = OLLAMA_TEMPLATES.get(ollama_template_name)
     if not ollama_modelfile:
-        print(f"Unsloth: No Ollama template mapping found for model '{base_model_name}'. Skipping Ollama Modelfile")
+        print(
+            f"Unsloth: No Ollama template mapping found for model '{base_model_name}'. Skipping Ollama Modelfile"
+        )
         return None
-    tokenizer._ollama_modelfile = ollama_modelfile  # This comes from the unpacking above
+    tokenizer._ollama_modelfile = (
+        ollama_modelfile  # This comes from the unpacking above
+    )
     modelfile = ollama_modelfile
 
     FILE_LOCATION_REPLACER = "ā«@ā
#š¦„__FILE_LOCATION__ā”@š¦„#āµ"
-    EOS_TOKEN_REPLACER     = "ā«@ā
#š¦„__EOS_TOKEN__ā”@š¦„#āµ"
-    LEFT_BRACKET_REPLACER  = "ā«@ā
#š¦„"
+    EOS_TOKEN_REPLACER = "ā«@ā
#š¦„__EOS_TOKEN__ā”@š¦„#āµ"
+    LEFT_BRACKET_REPLACER = "ā«@ā
#š¦„"
     RIGHT_BRACKET_REPLACER = "ā”@š¦„#āµ"
 
     # Fixes https://github.com/unslothai/unsloth/issues/1087
     # We must convert all {'s and }'s but keep {__FILE_LOCATION__} intact
-    modelfile = modelfile\
-        .replace("{__FILE_LOCATION__}", FILE_LOCATION_REPLACER)\
-        .replace("{__EOS_TOKEN__}",     EOS_TOKEN_REPLACER)\
-        .replace("{", LEFT_BRACKET_REPLACER)\
+    modelfile = (
+        modelfile.replace("{__FILE_LOCATION__}", FILE_LOCATION_REPLACER)
+        .replace("{__EOS_TOKEN__}", EOS_TOKEN_REPLACER)
+        .replace("{", LEFT_BRACKET_REPLACER)
         .replace("}", RIGHT_BRACKET_REPLACER)
+    )
 
     # Revert {__FILE_LOCATION__} back
-    modelfile = modelfile\
-        .replace(FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}")\
-        .replace(EOS_TOKEN_REPLACER,     "{__EOS_TOKEN__}")
+    modelfile = modelfile.replace(
+        FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}"
+    ).replace(EOS_TOKEN_REPLACER, "{__EOS_TOKEN__}")
 
     if "__EOS_TOKEN__" in modelfile:
         modelfile = modelfile.format(
-            __FILE_LOCATION__  = model_location,
-            __EOS_TOKEN__      = tokenizer.eos_token,
+            __FILE_LOCATION__ = model_location,
+            __EOS_TOKEN__ = tokenizer.eos_token,
         )
     else:
         modelfile = modelfile.format(
-            __FILE_LOCATION__  = model_location,
+            __FILE_LOCATION__ = model_location,
         )
-    pass
 
-    modelfile = modelfile\
-        .replace("ā«@ā
#š¦„", "{")\
-        .replace("ā”@š¦„#āµ", "}")\
-        .rstrip()
+    modelfile = modelfile.replace("ā«@ā
#š¦„", "{").replace("ā”@š¦„#āµ", "}").rstrip()
 
     return modelfile
-pass
 
-def create_ollama_model(
-    username: str,
-    model_name: str,
-    tag: str,
-    modelfile_path: str
-):
+
+def create_ollama_model(username: str, model_name: str, tag: str, modelfile_path: str):
     try:
         init_check = subprocess.run(
-            ['curl', 'http://localhost:11434'], capture_output=True, text=True,  timeout=3
+            ["curl", "http://localhost:11434"],
+            capture_output = True,
+            text = True,
+            timeout = 3,
         )
         if init_check.returncode == 0:
             print(init_check.stdout.strip())
@@ -1575,16 +1691,22 @@ def create_ollama_model(
         return "Ollama Request Timeout"
 
     process = subprocess.Popen(
-            ['ollama', 'create', f'{username}/{model_name}:{tag}', '-f', f'{modelfile_path}'],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
-        text=True,
-        bufsize=1,
-        universal_newlines=True
+        [
+            "ollama",
+            "create",
+            f"{username}/{model_name}:{tag}",
+            "-f",
+            f"{modelfile_path}",
+        ],
+        stdout = subprocess.PIPE,
+        stderr = subprocess.STDOUT,
+        text = True,
+        bufsize = 1,
+        universal_newlines = True,
     )
 
-    for line in iter(process.stdout.readline, ''):
-        print(line, end='')
+    for line in iter(process.stdout.readline, ""):
+        print(line, end = "")
         sys.stdout.flush()
 
     return_code = process.wait()
@@ -1593,13 +1715,15 @@ def create_ollama_model(
         print(f"\nMODEL CREATED FAILED WITH RETURN CODE {return_code}")
     else:
         print("\nMODEL CREATED SUCCESSFULLY")
-pass
 
 
 def push_to_ollama_hub(username: str, model_name: str, tag: str):
     try:
         init_check = subprocess.run(
-            ['curl', 'http://localhost:11434'], capture_output=True, text=True,  timeout=3
+            ["curl", "http://localhost:11434"],
+            capture_output = True,
+            text = True,
+            timeout = 3,
         )
         if init_check.returncode == 0:
             print(init_check.stdout.strip())
@@ -1609,16 +1733,16 @@ def push_to_ollama_hub(username: str, model_name: str, tag: str):
         return "Ollama Request Timeout"
 
     process = subprocess.Popen(
-            ['ollama', 'push', f'{username}/{model_name}:{tag}'],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
-        text=True,
-        bufsize=1,
-        universal_newlines=True
+        ["ollama", "push", f"{username}/{model_name}:{tag}"],
+        stdout = subprocess.PIPE,
+        stderr = subprocess.STDOUT,
+        text = True,
+        bufsize = 1,
+        universal_newlines = True,
     )
 
-    for line in iter(process.stdout.readline, ''):
-        print(line, end='')
+    for line in iter(process.stdout.readline, ""):
+        print(line, end = "")
         sys.stdout.flush()
 
     return_code = process.wait()
@@ -1627,18 +1751,11 @@ def push_to_ollama_hub(username: str, model_name: str, tag: str):
         print(f"\nMODEL PUBLISHED FAILED WITH RETURN CODE {return_code}")
     else:
         print("\nMODEL PUBLISHED SUCCESSFULLY")
-pass
 
-def push_to_ollama(
-    tokenizer,
-    gguf_location,
-    username: str,
-    model_name: str,
-    tag: str
-):
+
+def push_to_ollama(tokenizer, gguf_location, username: str, model_name: str, tag: str):
     model_file = create_ollama_modelfile(
-        tokenizer=tokenizer,
-        gguf_location=gguf_location
+        tokenizer = tokenizer, gguf_location = gguf_location
     )
 
     with open(f"Modelfile_{model_name}", "w", encoding = "utf-8") as f:
@@ -1646,78 +1763,73 @@ def push_to_ollama(
         f.close()
 
     create_ollama_model(
-        username=username,
-        model_name=model_name,
-        tag=tag,
-        modelfile_path=f"Modelfile_{model_name}"
+        username = username,
+        model_name = model_name,
+        tag = tag,
+        modelfile_path = f"Modelfile_{model_name}",
     )
 
-    push_to_ollama_hub(
-        username=username,
-        model_name=model_name,
-        tag=tag
-    )
+    push_to_ollama_hub(username = username, model_name = model_name, tag = tag)
 
     print("Successfully pushed to ollama")
-pass
 
 
 def unsloth_save_pretrained_gguf(
     self,
-    save_directory       : Union[str, os.PathLike],
-    tokenizer            = None,
-    quantization_method  = "fast_quantized",
-    first_conversion     : str = None,
-    push_to_hub          : bool = False,
-    token                : Optional[Union[str, bool]] = None,
-    private              : Optional[bool] = None,
-    is_main_process      : bool = True,
-    state_dict           : Optional[dict] = None,
-    save_function        : Callable = torch.save,
-    max_shard_size       : Union[int, str] = "5GB",
-    safe_serialization   : bool = True,
-    variant              : Optional[str] = None,
-    save_peft_format     : bool = True,
-    tags                 : List[str] = None,
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.85,
+    save_directory: Union[str, os.PathLike],
+    tokenizer = None,
+    quantization_method = "fast_quantized",
+    first_conversion: str = None,
+    push_to_hub: bool = False,
+    token: Optional[Union[str, bool]] = None,
+    private: Optional[bool] = None,
+    is_main_process: bool = True,
+    state_dict: Optional[dict] = None,
+    save_function: Callable = torch.save,
+    max_shard_size: Union[int, str] = "5GB",
+    safe_serialization: bool = True,
+    variant: Optional[str] = None,
+    save_peft_format: bool = True,
+    tags: List[str] = None,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.85,
 ):
     """
-        Same as .save_pretrained(...) except 4bit weights are auto
-        converted to float16 then converted to GGUF / llama.cpp format.
-
-        Choose for `quantization_method` to be:
-        "not_quantized"  : "Recommended. Fast conversion. Slow inference, big files.",
-        "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
-        "quantized"      : "Recommended. Slow conversion. Fast inference, small files.",
-        "f32"     : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
-        "f16"     : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
-        "q8_0"    : "Fast conversion. High resource use, but generally acceptable.",
-        "q4_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
-        "q5_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
-        "q2_k"    : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
-        "q3_k_l"  : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
-        "q3_k_m"  : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
-        "q3_k_s"  : "Uses Q3_K for all tensors",
-        "q4_0"    : "Original quant method, 4-bit.",
-        "q4_1"    : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
-        "q4_k_s"  : "Uses Q4_K for all tensors",
-        "q4_k"    : "alias for q4_k_m",
-        "q5_k"    : "alias for q5_k_m",
-        "q5_0"    : "Higher accuracy, higher resource usage and slower inference.",
-        "q5_1"    : "Even higher accuracy, resource usage and slower inference.",
-        "q5_k_s"  : "Uses Q5_K for all tensors",
-        "q6_k"    : "Uses Q8_K for all tensors",
-        "iq2_xxs" : "2.06 bpw quantization",
-        "iq2_xs"  : "2.31 bpw quantization",
-        "iq3_xxs" : "3.06 bpw quantization",
-        "q3_k_xs" : "3-bit extra small quantization",
+    Same as .save_pretrained(...) except 4bit weights are auto
+    converted to float16 then converted to GGUF / llama.cpp format.
+
+    Choose for `quantization_method` to be:
+    "not_quantized"  : "Recommended. Fast conversion. Slow inference, big files.",
+    "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
+    "quantized"      : "Recommended. Slow conversion. Fast inference, small files.",
+    "f32"     : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
+    "f16"     : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+    "q8_0"    : "Fast conversion. High resource use, but generally acceptable.",
+    "q4_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
+    "q5_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
+    "q2_k"    : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
+    "q3_k_l"  : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+    "q3_k_m"  : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+    "q3_k_s"  : "Uses Q3_K for all tensors",
+    "q4_0"    : "Original quant method, 4-bit.",
+    "q4_1"    : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
+    "q4_k_s"  : "Uses Q4_K for all tensors",
+    "q4_k"    : "alias for q4_k_m",
+    "q5_k"    : "alias for q5_k_m",
+    "q5_0"    : "Higher accuracy, higher resource usage and slower inference.",
+    "q5_1"    : "Even higher accuracy, resource usage and slower inference.",
+    "q5_k_s"  : "Uses Q5_K for all tensors",
+    "q6_k"    : "Uses Q8_K for all tensors",
+    "iq2_xxs" : "2.06 bpw quantization",
+    "iq2_xs"  : "2.31 bpw quantization",
+    "iq3_xxs" : "3.06 bpw quantization",
+    "q3_k_xs" : "3-bit extra small quantization",
     """
     if tokenizer is None:
         raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
 
     try:
-        base_model_name = get_model_name(self.config._name_or_path, load_in_4bit=False)
+        base_model_name = get_model_name(self.config._name_or_path, load_in_4bit = False)
         model_name = base_model_name.split("/")[-1]
     except:
         base_model_name = self.config._name_or_path
@@ -1731,7 +1843,7 @@ def unsloth_save_pretrained_gguf(
 
     # Step 1: Check if this is a VLM (Vision-Language Model) and check if gpt-oss
     is_vlm = False
-    if hasattr(self, 'config') and hasattr(self.config, 'architectures'):
+    if hasattr(self, "config") and hasattr(self.config, "architectures"):
         is_vlm = any(
             x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
             for x in self.config.architectures
@@ -1740,12 +1852,23 @@ def unsloth_save_pretrained_gguf(
 
     is_processor = is_vlm and isinstance(tokenizer, ProcessorMixin)
 
-    is_gpt_oss = True if (hasattr(self.config, "architectures") and self.config.architectures == "GptOssForCausalLM") or (hasattr(self.config, "model_type") and self.config.model_type in ["gpt-oss", "gpt_oss"]) else False
+    is_gpt_oss = (
+        True
+        if (
+            hasattr(self.config, "architectures")
+            and self.config.architectures == "GptOssForCausalLM"
+        )
+        or (
+            hasattr(self.config, "model_type")
+            and self.config.model_type in ["gpt-oss", "gpt_oss"]
+        )
+        else False
+    )
     # Step 2: Prepare arguments for model saving
     arguments = dict(locals())
-    arguments["model"]        = self
-    arguments["tokenizer"]    = tokenizer
-    arguments["push_to_hub"]  = False # We handle upload ourselves
+    arguments["model"] = self
+    arguments["tokenizer"] = tokenizer
+    arguments["push_to_hub"] = False  # We handle upload ourselves
     # GPT-OSS needs mxfp4 save method
     if is_gpt_oss:
         arguments["save_method"] = "mxfp4"
@@ -1767,7 +1890,9 @@ def unsloth_save_pretrained_gguf(
         fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
 
     # Step 4: Save/merge model to 16-bit format
-    print(f'Unsloth: Merging model weights to {"mxfp4" if is_gpt_oss else "16-bit"} format...')
+    print(
+        f"Unsloth: Merging model weights to {'mxfp4' if is_gpt_oss else '16-bit'} format..."
+    )
     try:
         # Call unsloth_generic_save directly (it's in the same file)
         unsloth_generic_save(**arguments)
@@ -1781,11 +1906,11 @@ def unsloth_save_pretrained_gguf(
     # Use old chat template if the bos is removed
     if fix_bos_token:
         tokenizer.chat_template = old_chat_template
-    pass
 
     # Step 6: Clean up memory
     for _ in range(3):
         import gc
+
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -1793,9 +1918,9 @@ def unsloth_save_pretrained_gguf(
     # Step 7: Get model dtype and type
     try:
         model_dtype = dtype_from_config(self.config)
-        model_type  = self.config.model_type
+        model_type = self.config.model_type
         if type(model_dtype) is str:
-            assert(model_dtype == "float16" or model_dtype == "bfloat16")
+            assert model_dtype == "float16" or model_dtype == "bfloat16"
         elif model_dtype == torch.float16:
             model_dtype = "float16"
         elif model_dtype == torch.bfloat16:
@@ -1815,33 +1940,41 @@ def unsloth_save_pretrained_gguf(
     quantization_methods = []
     if quantization_method is not None:
         # Convert quantization_method to list
-        if   isinstance(quantization_method, list):  pass
-        elif isinstance(quantization_method, str):   quantization_method = [ quantization_method, ]
-        elif isinstance(quantization_method, tuple): quantization_method = list(quantization_method)
+        if isinstance(quantization_method, list):
+            pass
+        elif isinstance(quantization_method, str):
+            quantization_method = [
+                quantization_method,
+            ]
+        elif isinstance(quantization_method, tuple):
+            quantization_method = list(quantization_method)
         else:
-            raise TypeError("Unsloth: quantization_method can only be a string or a list of strings")
-        pass
+            raise TypeError(
+                "Unsloth: quantization_method can only be a string or a list of strings"
+            )
         for i, quant_method in enumerate(quantization_method):
             quant_method = quant_method.lower()
-            if   quant_method == "not_quantized":  quant_method = "f16"
-            elif quant_method == "fast_quantized": quant_method = "q8_0"
-            elif quant_method == "quantized":      quant_method = "q4_k_m"
-            elif quant_method is None:             quant_method = "q8_0"
+            if quant_method == "not_quantized":
+                quant_method = "f16"
+            elif quant_method == "fast_quantized":
+                quant_method = "q8_0"
+            elif quant_method == "quantized":
+                quant_method = "q4_k_m"
+            elif quant_method is None:
+                quant_method = "q8_0"
             quantization_methods.append(quant_method.lower())
-        pass
-    pass
 
     try:
         all_file_locations, want_full_precision, is_vlm_update = save_to_gguf(
-            model_name=model_name,
-            model_type=model_type,
-            model_dtype=model_dtype,
-            is_sentencepiece=False,
-            model_directory=save_directory,
-            quantization_method=quantization_methods,
-            first_conversion=first_conversion,
-            is_vlm=is_vlm,  # Pass VLM flag
-            is_gpt_oss = is_gpt_oss, # Pass gpt_oss Flag
+            model_name = model_name,
+            model_type = model_type,
+            model_dtype = model_dtype,
+            is_sentencepiece = False,
+            model_directory = save_directory,
+            quantization_method = quantization_methods,
+            first_conversion = first_conversion,
+            is_vlm = is_vlm,  # Pass VLM flag
+            is_gpt_oss = is_gpt_oss,  # Pass gpt_oss Flag
         )
     except Exception as e:
         if IS_KAGGLE_ENVIRONMENT:
@@ -1862,7 +1995,9 @@ def unsloth_save_pretrained_gguf(
             if is_vlm_update:
                 modelfile = create_ollama_modelfile(tokenizer, base_model_name, ".")
             else:
-                modelfile = create_ollama_modelfile(tokenizer, base_model_name, all_file_locations[0])
+                modelfile = create_ollama_modelfile(
+                    tokenizer, base_model_name, all_file_locations[0]
+                )
             if modelfile is not None:
                 if is_vlm_update:
                     modelfile_location = os.path.join(save_directory, "Modelfile")
@@ -1877,26 +2012,33 @@ def unsloth_save_pretrained_gguf(
     # Step 10: Show BOS token warning if applicable
     if fix_bos_token:
         logger.warning(
-            "Unsloth: ##### The current model auto adds a BOS token.\n"\
+            "Unsloth: ##### The current model auto adds a BOS token.\n"
             "Unsloth: ##### We removed it in GGUF's chat template for you."
         )
-    pass
 
     if is_vlm_update:
         print("\n")
-        print(f"Unsloth: example usage for Multimodal LLMs: llama-mtmd-cli -m {all_file_locations[0]} --mmproj {all_file_locations[-1]}")
+        print(
+            f"Unsloth: example usage for Multimodal LLMs: llama-mtmd-cli -m {all_file_locations[0]} --mmproj {all_file_locations[-1]}"
+        )
         print("Unsloth: load image inside llama.cpp runner: /image test_image.jpg")
         print("Unsloth: Prompt model to describe the image")
     else:
-        print(f'Unsloth: example usage for text only LLMs: llama-cli --model {all_file_locations[0]} -p "why is the sky blue?"')
+        print(
+            f'Unsloth: example usage for text only LLMs: llama-cli --model {all_file_locations[0]} -p "why is the sky blue?"'
+        )
     if ollama_success and is_vlm_update:
         print(f"Unsloth: Saved Ollama Modelfile to {modelfile_location}")
-        print("Unsloth: convert model to ollama format by running - ollama create model_name -f ./Modelfile - inside save directory.")
+        print(
+            "Unsloth: convert model to ollama format by running - ollama create model_name -f ./Modelfile - inside save directory."
+        )
     if ollama_success and not is_vlm_update:
         print("Unsloth: Saved Ollama Modelfile to current directory")
-        print("Unsloth: convert model to ollama format by running - ollama create model_name -f ./Modelfile - inside current directory.")
+        print(
+            "Unsloth: convert model to ollama format by running - ollama create model_name -f ./Modelfile - inside current directory."
+        )
 
-    #Return a dict with all needed info for push_to_hub
+    # Return a dict with all needed info for push_to_hub
     return {
         "save_directory": save_directory,
         "gguf_files": all_file_locations,
@@ -1905,52 +2047,51 @@ def unsloth_save_pretrained_gguf(
         "is_vlm": is_vlm_update,
         "fix_bos_token": fix_bos_token,
     }
-pass
 
 
 def unsloth_push_to_hub_gguf(
     self,
-    repo_id              : str,
-    tokenizer            = None,
-    quantization_method  = "fast_quantized",
-    first_conversion     : str = None,
-    use_temp_dir         : Optional[bool] = None,
-    commit_message       : Optional[str] = "Trained with Unsloth",
-    private              : Optional[bool] = None,
-    token                : Union[bool, str, None] = None,
-    max_shard_size       : Union[int, str, None] = "5GB",
-    create_pr            : bool = False,
-    safe_serialization   : bool = True,
-    revision             : str = None,
-    commit_description   : str = "Upload model trained with Unsloth 2x faster",
-    tags                 : Optional[List[str]] = None,
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.85,
+    repo_id: str,
+    tokenizer = None,
+    quantization_method = "fast_quantized",
+    first_conversion: str = None,
+    use_temp_dir: Optional[bool] = None,
+    commit_message: Optional[str] = "Trained with Unsloth",
+    private: Optional[bool] = None,
+    token: Union[bool, str, None] = None,
+    max_shard_size: Union[int, str, None] = "5GB",
+    create_pr: bool = False,
+    safe_serialization: bool = True,
+    revision: str = None,
+    commit_description: str = "Upload model trained with Unsloth 2x faster",
+    tags: Optional[List[str]] = None,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.85,
 ):
     """
-        Same as .push_to_hub(...) except 4bit weights are auto
-        converted to float16 then converted to GGUF / llama.cpp format.
-
-        Choose for `quantization_method` to be:
-        "not_quantized"  : "Recommended. Fast conversion. Slow inference, big files.",
-        "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
-        "quantized"      : "Recommended. Slow conversion. Fast inference, small files.",
-        "f32"     : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
-        "f16"     : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
-        "q8_0"    : "Fast conversion. High resource use, but generally acceptable.",
-        "q4_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
-        "q5_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
-        "q2_k"    : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
-        "q3_k_l"  : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
-        "q3_k_m"  : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
-        "q3_k_s"  : "Uses Q3_K for all tensors",
-        "q4_0"    : "Original quant method, 4-bit.",
-        "q4_1"    : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
-        "q4_k_s"  : "Uses Q4_K for all tensors",
-        "q5_0"    : "Higher accuracy, higher resource usage and slower inference.",
-        "q5_1"    : "Even higher accuracy, resource usage and slower inference.",
-        "q5_k_s"  : "Uses Q5_K for all tensors",
-        "q6_k"    : "Uses Q8_K for all tensors",
+    Same as .push_to_hub(...) except 4bit weights are auto
+    converted to float16 then converted to GGUF / llama.cpp format.
+
+    Choose for `quantization_method` to be:
+    "not_quantized"  : "Recommended. Fast conversion. Slow inference, big files.",
+    "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
+    "quantized"      : "Recommended. Slow conversion. Fast inference, small files.",
+    "f32"     : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
+    "f16"     : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+    "q8_0"    : "Fast conversion. High resource use, but generally acceptable.",
+    "q4_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
+    "q5_k_m"  : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
+    "q2_k"    : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
+    "q3_k_l"  : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+    "q3_k_m"  : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+    "q3_k_s"  : "Uses Q3_K for all tensors",
+    "q4_0"    : "Original quant method, 4-bit.",
+    "q4_1"    : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
+    "q4_k_s"  : "Uses Q4_K for all tensors",
+    "q5_0"    : "Higher accuracy, higher resource usage and slower inference.",
+    "q5_1"    : "Even higher accuracy, resource usage and slower inference.",
+    "q5_k_s"  : "Uses Q5_K for all tensors",
+    "q6_k"    : "Uses Q8_K for all tensors",
     """
     if tokenizer is None:
         raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
@@ -1960,7 +2101,8 @@ def unsloth_push_to_hub_gguf(
 
     if use_temp_dir or use_temp_dir is None:
         import tempfile
-        temp_dir = tempfile.mkdtemp(prefix="unsloth_gguf_")
+
+        temp_dir = tempfile.mkdtemp(prefix = "unsloth_gguf_")
         save_directory = temp_dir
         cleanup_temp = True
     else:
@@ -1973,17 +2115,17 @@ def unsloth_push_to_hub_gguf(
     try:
         # Call save_pretrained_gguf - it returns all the info we need
         result = unsloth_save_pretrained_gguf(
-            self=self,
-            save_directory=save_directory,
-            tokenizer=tokenizer,
-            quantization_method=quantization_method,
-            first_conversion=first_conversion,
-            push_to_hub=False,  # Never push from here
-            token=None,  # Don't need token for local save
-            max_shard_size=max_shard_size,
-            safe_serialization=safe_serialization,
-            temporary_location=temporary_location,
-            maximum_memory_usage=maximum_memory_usage,
+            self = self,
+            save_directory = save_directory,
+            tokenizer = tokenizer,
+            quantization_method = quantization_method,
+            first_conversion = first_conversion,
+            push_to_hub = False,  # Never push from here
+            token = None,  # Don't need token for local save
+            max_shard_size = max_shard_size,
+            safe_serialization = safe_serialization,
+            temporary_location = temporary_location,
+            maximum_memory_usage = maximum_memory_usage,
         )
 
         # Extract results
@@ -1997,19 +2139,20 @@ def unsloth_push_to_hub_gguf(
     except Exception as e:
         if cleanup_temp:
             import shutil
+
             try:
                 shutil.rmtree(save_directory)
             except:
                 pass
         raise RuntimeError(f"Failed to convert model to GGUF: {e}")
-    pass
 
     # Step 3: Upload to HuggingFace Hub
     print("Unsloth: Uploading GGUF to Huggingface Hub...")
 
     try:
         from huggingface_hub import HfApi
-        api = HfApi(token=token)
+
+        api = HfApi(token = token)
 
         # Get full repo id
         if "/" not in repo_id:
@@ -2020,10 +2163,10 @@ def unsloth_push_to_hub_gguf(
 
         # Create repo
         api.create_repo(
-            repo_id=full_repo_id,
-            repo_type="model",
-            private=private,
-            exist_ok=True,
+            repo_id = full_repo_id,
+            repo_type = "model",
+            private = private,
+            exist_ok = True,
         )
 
         # Upload GGUF files
@@ -2032,50 +2175,55 @@ def unsloth_push_to_hub_gguf(
             # Replace temp directory name with proper model name
             if cleanup_temp and "unsloth_gguf_" in original_name:
                 # Extract the quantization part (e.g., ".Q8_0.gguf" or ".Q8_0-mmproj.gguf")
-                quant_suffix = original_name.split(".", 1)[1] if "." in original_name else original_name
+                quant_suffix = (
+                    original_name.split(".", 1)[1]
+                    if "." in original_name
+                    else original_name
+                )
                 proper_name = f"{model_name}.{quant_suffix}"
             else:
-                proper_name = original_name.replace(os.path.basename(save_directory), model_name)
+                proper_name = original_name.replace(
+                    os.path.basename(save_directory), model_name
+                )
 
             print(f"Uploading {proper_name}...")
 
             api.upload_file(
-                path_or_fileobj=file_location,
-                path_in_repo=proper_name,
-                repo_id=full_repo_id,
-                repo_type="model",
-                commit_message=commit_message,
-                commit_description=commit_description,
-                create_pr=create_pr,
-                revision=revision,
+                path_or_fileobj = file_location,
+                path_in_repo = proper_name,
+                repo_id = full_repo_id,
+                repo_type = "model",
+                commit_message = commit_message,
+                commit_description = commit_description,
+                create_pr = create_pr,
+                revision = revision,
             )
-        pass
 
         # Upload config.json if exists
         config_path = os.path.join(actual_save_directory, "config.json")
         if os.path.exists(config_path):
             print("Uploading config.json...")
             api.upload_file(
-                path_or_fileobj=config_path,
-                path_in_repo="config.json",
-                repo_id=full_repo_id,
-                repo_type="model",
-                commit_message=f"{commit_message} - config",
-                create_pr=create_pr,
-                revision=revision,
+                path_or_fileobj = config_path,
+                path_in_repo = "config.json",
+                repo_id = full_repo_id,
+                repo_type = "model",
+                commit_message = f"{commit_message} - config",
+                create_pr = create_pr,
+                revision = revision,
             )
 
         # Upload Modelfile if exists
         if modelfile_location and os.path.exists(modelfile_location):
             print("Uploading Ollama Modelfile...")
             api.upload_file(
-                path_or_fileobj=modelfile_location,
-                path_in_repo="Modelfile",
-                repo_id=full_repo_id,
-                repo_type="model",
-                commit_message=f"{commit_message} - Ollama Modelfile",
-                create_pr=create_pr,
-                revision=revision,
+                path_or_fileobj = modelfile_location,
+                path_in_repo = "Modelfile",
+                repo_id = full_repo_id,
+                repo_type = "model",
+                commit_message = f"{commit_message} - Ollama Modelfile",
+                create_pr = create_pr,
+                revision = revision,
             )
 
         # Create and upload README
@@ -2101,10 +2249,16 @@ def unsloth_push_to_hub_gguf(
             # Fix filename in README too
             original_name = os.path.basename(file)
             if cleanup_temp and "unsloth_gguf_" in original_name:
-                quant_suffix = original_name.split(".", 1)[1] if "." in original_name else original_name
+                quant_suffix = (
+                    original_name.split(".", 1)[1]
+                    if "." in original_name
+                    else original_name
+                )
                 proper_name = f"{model_name}.{quant_suffix}"
             else:
-                proper_name = original_name.replace(os.path.basename(save_directory), model_name)
+                proper_name = original_name.replace(
+                    os.path.basename(save_directory), model_name
+                )
             readme_content += f"- `{proper_name}`\n"
 
         # Special note for VLM with Modelfile
@@ -2115,31 +2269,36 @@ def unsloth_push_to_hub_gguf(
             readme_content += "1. Place the `Modelfile` in the same directory as the finetuned bf16 merged model\n"
             readme_content += "3. Run: `ollama create model_name -f ./Modelfile`\n"
             readme_content += "   (Replace `model_name` with your desired name)\n\n"
-            readme_content += "This will create a unified bf16 model that Ollama can use.\n"
+            readme_content += (
+                "This will create a unified bf16 model that Ollama can use.\n"
+            )
         elif modelfile_location:
             readme_content += "\n## Ollama\n"
             readme_content += "An Ollama Modelfile is included for easy deployment.\n"
 
-
         if fix_bos_token:
             readme_content += "\n## Note\n"
-            readme_content += "The model's BOS token behavior was adjusted for GGUF compatibility.\n"
+            readme_content += (
+                "The model's BOS token behavior was adjusted for GGUF compatibility.\n"
+            )
 
         readme_path = os.path.join(actual_save_directory, "README.md")
         with open(readme_path, "w") as f:
             f.write(readme_content)
 
         api.upload_file(
-            path_or_fileobj=readme_path,
-            path_in_repo="README.md",
-            repo_id=full_repo_id,
-            repo_type="model",
-            commit_message="Add README",
-            create_pr=create_pr,
-            revision=revision,
+            path_or_fileobj = readme_path,
+            path_in_repo = "README.md",
+            repo_id = full_repo_id,
+            repo_type = "model",
+            commit_message = "Add README",
+            create_pr = create_pr,
+            revision = revision,
         )
 
-        print(f"Unsloth: Successfully uploaded GGUF to https://huggingface.co/{full_repo_id}")
+        print(
+            f"Unsloth: Successfully uploaded GGUF to https://huggingface.co/{full_repo_id}"
+        )
 
         # Add tags
         if tags is None:
@@ -2150,9 +2309,9 @@ def unsloth_push_to_hub_gguf(
 
         try:
             api.add_tags(
-                repo_id=full_repo_id,
-                tags=tags,
-                repo_type="model",
+                repo_id = full_repo_id,
+                tags = tags,
+                repo_type = "model",
             )
         except:
             pass
@@ -2165,29 +2324,30 @@ def unsloth_push_to_hub_gguf(
         if cleanup_temp and os.path.exists(save_directory):
             print("Unsloth: Cleaning up temporary files...")
             import shutil
+
             try:
                 shutil.rmtree(save_directory)
             except:
                 pass
 
     return full_repo_id
-pass
 
 
 # Corrected function to save LoRA to a custom directory
 def save_lora_to_custom_dir(model, tokenizer, save_directory):
     # Create the custom directory if it doesn't exist
-    os.makedirs(save_directory, exist_ok=True)
+    os.makedirs(save_directory, exist_ok = True)
 
     # Call the unsloth_save_model function with the custom directory
     unsloth_save_model(
         model,
         tokenizer,
-        save_directory=save_directory,
-        save_method="lora",
-        push_to_hub=False,
+        save_directory = save_directory,
+        save_method = "lora",
+        push_to_hub = False,
     )
 
+
 # Corrected method within the model class to convert LoRA to GGML and push to Hugging Face Hub
 def unsloth_convert_lora_to_ggml_and_push_to_hub(
     self,
@@ -2207,7 +2367,7 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub(
         if IS_KAGGLE_ENVIRONMENT:
             python_install = install_python_non_blocking(["protobuf"])
             python_install.wait()
-            install_llama_cpp_blocking(use_cuda=False)
+            install_llama_cpp_blocking(use_cuda = False)
             makefile = None
         else:
             git_clone = install_llama_cpp_clone_non_blocking()
@@ -2227,17 +2387,26 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub(
     model_type = self.config.model_type
     output_file = os.path.join(lora_directory_push, "ggml-adapter-model.bin")
 
-    print(f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format.")
+    print(
+        f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format."
+    )
     print(f"The output file will be {output_file}")
 
     command = f"python3 llama.cpp/convert-lora-to-ggml.py {lora_directory_push} {output_file} llama"
 
     try:
-        with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp:
+        with subprocess.Popen(
+            command,
+            shell = True,
+            stdout = subprocess.PIPE,
+            stderr = subprocess.PIPE,
+            bufsize = 1,
+            universal_newlines = True,
+        ) as sp:
             for line in sp.stdout:
-                print(line, end="", flush=True)
+                print(line, end = "", flush = True)
             for line in sp.stderr:
-                print(line, end="", flush=True)
+                print(line, end = "", flush = True)
             sp.wait()
             if sp.returncode != 0:
                 raise subprocess.CalledProcessError(sp.returncode, command)
@@ -2249,17 +2418,26 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub(
 
     print("Unsloth: Uploading GGML file to Hugging Face Hub...")
     username = upload_to_huggingface(
-        self, repo_id, token,
-        "GGML converted LoRA", "ggml", output_file, None, private,
+        self,
+        repo_id,
+        token,
+        "GGML converted LoRA",
+        "ggml",
+        output_file,
+        None,
+        private,
     )
     link = f"{repo_id.lstrip('/')}"
     print("Unsloth: Done.")
     print(f"Converted LoRA to GGML and uploaded to https://huggingface.co/{link}")
-    print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!")
+    print(
+        "\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!"
+    )
+
 
 def unsloth_convert_lora_to_ggml_and_save_locally(
     self,
-    save_directory: str, # Added parameter for the folder name
+    save_directory: str,  # Added parameter for the folder name
     tokenizer,
     temporary_location: str = "_unsloth_temporary_saved_buffers",
     maximum_memory_usage: float = 0.85,
@@ -2268,7 +2446,7 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
         if IS_KAGGLE_ENVIRONMENT:
             python_install = install_python_non_blocking(["protobuf"])
             python_install.wait()
-            install_llama_cpp_blocking(use_cuda=False)
+            install_llama_cpp_blocking(use_cuda = False)
             makefile = None
         else:
             git_clone = install_llama_cpp_clone_non_blocking()
@@ -2288,17 +2466,26 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
     model_type = self.config.model_type
     output_file = os.path.join(save_directory, "ggml-adapter-model.bin")
 
-    print(f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format.")
+    print(
+        f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format."
+    )
     print(f"The output file will be {output_file}")
 
     command = f"python3 llama.cpp/convert-lora-to-ggml.py {save_directory} {output_file} llama"
 
     try:
-        with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp:
+        with subprocess.Popen(
+            command,
+            shell = True,
+            stdout = subprocess.PIPE,
+            stderr = subprocess.PIPE,
+            bufsize = 1,
+            universal_newlines = True,
+        ) as sp:
             for line in sp.stdout:
-                print(line, end="", flush=True)
+                print(line, end = "", flush = True)
             for line in sp.stderr:
-                print(line, end="", flush=True)
+                print(line, end = "", flush = True)
             sp.wait()
             if sp.returncode != 0:
                 raise subprocess.CalledProcessError(sp.returncode, command)
@@ -2307,8 +2494,9 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
         return
     print("Unsloth: Done.")
     print(f"Unsloth: Conversion completed! Output file: {output_file}")
-    print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!")
-pass
+    print(
+        "\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!"
+    )
 
 
 from .models.loader_utils import get_model_name
@@ -2321,6 +2509,7 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
     convert_to_gguf as _convert_to_gguf,
 )
 
+
 @torch.inference_mode
 def save_to_gguf_generic(
     model,
@@ -2331,32 +2520,41 @@ def save_to_gguf_generic(
     repo_id = None,
     token = None,
 ):
-    if token is None and repo_id is not None: token = get_token()
+    if token is None and repo_id is not None:
+        token = get_token()
     if repo_id is not None and token is None:
         raise RuntimeError("Unsloth: Please specify a token for uploading!")
 
     if not os.path.exists(os.path.join("llama.cpp", "unsloth_convert_hf_to_gguf.py")):
         install_llama_cpp(just_clone_repo = True)
-    pass
 
     # Use old style quantization_method
     new_quantization_methods = []
     if quantization_method is not None:
         # Convert quantization_method to list
-        if   isinstance(quantization_method, list):  pass
-        elif isinstance(quantization_method, str):   quantization_method = [ quantization_method, ]
-        elif isinstance(quantization_method, tuple): quantization_method = list(quantization_method)
+        if isinstance(quantization_method, list):
+            pass
+        elif isinstance(quantization_method, str):
+            quantization_method = [
+                quantization_method,
+            ]
+        elif isinstance(quantization_method, tuple):
+            quantization_method = list(quantization_method)
         else:
-            raise TypeError("Unsloth: quantization_method can only be a string or a list of strings")
-        pass
+            raise TypeError(
+                "Unsloth: quantization_method can only be a string or a list of strings"
+            )
         for i, quant_method in enumerate(quantization_method):
             quant_method = quant_method.lower()
-            if   quant_method == "not_quantized":  quant_method = "f16"
-            elif quant_method == "fast_quantized": quant_method = "q8_0"
-            elif quant_method == "quantized":      quant_method = "q4_k_m"
-            elif quant_method is None:             quant_method = "q8_0"
+            if quant_method == "not_quantized":
+                quant_method = "f16"
+            elif quant_method == "fast_quantized":
+                quant_method = "q8_0"
+            elif quant_method == "quantized":
+                quant_method = "q4_k_m"
+            elif quant_method is None:
+                quant_method = "q8_0"
             new_quantization_methods.append(quant_method.lower())
-        pass
     else:
         new_quantization_methods.append(quantization_type.lower())
     # Check if wrong method
@@ -2366,8 +2564,6 @@ def save_to_gguf_generic(
             for key, value in ALLOWED_QUANTS.items():
                 error += f"[{key}] => {value}\n"
             raise RuntimeError(error)
-        pass
-    pass
 
     # Go through all types and save individually - somewhat inefficient
     # since we save F16 / BF16 multiple times
@@ -2388,6 +2584,7 @@ def save_to_gguf_generic(
             )
 
             from huggingface_hub import HfApi
+
             api = HfApi(token = token)
             api.upload_folder(
                 folder_path = save_directory,
@@ -2395,48 +2592,44 @@ def save_to_gguf_generic(
                 repo_type = "model",
                 allow_patterns = ["*.gguf"],
             )
-        pass
-    pass
     return metadata
-pass
 
 
 @torch.inference_mode
 def unsloth_generic_save(
     model,
     tokenizer,
-    save_directory       : Union[str, os.PathLike] = "unsloth_finetuned_merge",
-    save_method          : str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
-    push_to_hub          : bool = False,
-    token                : Optional[Union[str, bool]] = None,
-    is_main_process      : bool = True,
-    state_dict           : Optional[dict] = None,
-    save_function        : Callable = torch.save,
-    max_shard_size       : Union[int, str] = "5GB",
-    safe_serialization   : bool = True,
-    variant              : Optional[str] = None,
-    save_peft_format     : bool = True,
-
+    save_directory: Union[str, os.PathLike] = "unsloth_finetuned_merge",
+    save_method: str = "lora",  # ["lora", "merged_16bit", "merged_4bit"]
+    push_to_hub: bool = False,
+    token: Optional[Union[str, bool]] = None,
+    is_main_process: bool = True,
+    state_dict: Optional[dict] = None,
+    save_function: Callable = torch.save,
+    max_shard_size: Union[int, str] = "5GB",
+    safe_serialization: bool = True,
+    variant: Optional[str] = None,
+    save_peft_format: bool = True,
     # Push to hub
-    use_temp_dir         : Optional[bool] = None,
-    commit_message       : Optional[str] = "Trained with Unsloth",
-    private              : Optional[bool] = None,
-    create_pr            : bool = False,
-    revision             : str = None,
-    commit_description   : str = "Upload model trained with Unsloth 2x faster",
-    tags                 : List[str] = None,
-
+    use_temp_dir: Optional[bool] = None,
+    commit_message: Optional[str] = "Trained with Unsloth",
+    private: Optional[bool] = None,
+    create_pr: bool = False,
+    revision: str = None,
+    commit_description: str = "Upload model trained with Unsloth 2x faster",
+    tags: List[str] = None,
     # Our functions
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.9,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.9,
 ):
-    if token is None and push_to_hub: token = get_token()
+    if token is None and push_to_hub:
+        token = get_token()
 
     if save_method == "merged_4bit":
         raise RuntimeError(
-            "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"\
-            "to merge to GGUF or others later on. I suggest you to do this as a final step\n"\
-            "if you're planning to do multiple saves.\n"\
+            "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"
+            "to merge to GGUF or others later on. I suggest you to do this as a final step\n"
+            "if you're planning to do multiple saves.\n"
             "If you are certain, change `save_method` to `merged_4bit_forced`."
         )
     elif save_method == "merged_4bit_forced":
@@ -2444,54 +2637,52 @@ def unsloth_generic_save(
 
     merge_and_overwrite_lora(
         get_model_name,
-        model                = model,
-        tokenizer            = tokenizer,
-        save_directory       = save_directory,
-        push_to_hub          = push_to_hub,
-        private              = private,
-        token                = token,
-        save_method          = save_method,
-        output_dtype         = None,
+        model = model,
+        tokenizer = tokenizer,
+        save_directory = save_directory,
+        push_to_hub = push_to_hub,
+        private = private,
+        token = token,
+        save_method = save_method,
+        output_dtype = None,
         low_disk_space_usage = True,
-        use_temp_file        = False,
+        use_temp_file = False,
     )
     return
-pass
 
 
 def unsloth_generic_save_pretrained_merged(
     self,
-    save_directory       : Union[str, os.PathLike],
-    tokenizer            = None,
-    save_method          : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
-    push_to_hub          : bool = False,
-    token                : Optional[Union[str, bool]] = None,
-    is_main_process      : bool = True,
-    state_dict           : Optional[dict] = None,
-    save_function        : Callable = torch.save,
-    max_shard_size       : Union[int, str] = "5GB",
-    safe_serialization   : bool = True,
-    variant              : Optional[str] = None,
-    save_peft_format     : bool = True,
-    tags                 : List[str] = None,
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.75,
+    save_directory: Union[str, os.PathLike],
+    tokenizer = None,
+    save_method: str = "merged_16bit",  # ["lora", "merged_16bit", "merged_4bit"]
+    push_to_hub: bool = False,
+    token: Optional[Union[str, bool]] = None,
+    is_main_process: bool = True,
+    state_dict: Optional[dict] = None,
+    save_function: Callable = torch.save,
+    max_shard_size: Union[int, str] = "5GB",
+    safe_serialization: bool = True,
+    variant: Optional[str] = None,
+    save_peft_format: bool = True,
+    tags: List[str] = None,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.75,
 ):
     """
-        Same as .push_to_hub(...) except 4bit weights are auto
-        converted to float16 with as few overhead as possible.
+    Same as .push_to_hub(...) except 4bit weights are auto
+    converted to float16 with as few overhead as possible.
 
-        Choose for `save_method` to be either:
-        1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
-        2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
-        3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+    Choose for `save_method` to be either:
+    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+    3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
     """
     if tokenizer is None:
         logger.warning_once(
-            "Unsloth: You're not saving a tokenizer as well?\n"\
+            "Unsloth: You're not saving a tokenizer as well?\n"
             "You can do it separately via `tokenizer.save_pretrained(...)`"
         )
-    pass
 
     arguments = dict(locals())
     arguments["model"] = self
@@ -2499,62 +2690,59 @@ def unsloth_generic_save_pretrained_merged(
     unsloth_generic_save(**arguments)
     for _ in range(3):
         gc.collect()
-pass
 
 
 def unsloth_generic_push_to_hub_merged(
     self,
-    repo_id              : str,
-    tokenizer            = None,
-    save_method          : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
-    use_temp_dir         : Optional[bool] = None,
-    commit_message       : Optional[str] = "Trained with Unsloth",
-    private              : Optional[bool] = None,
-    token                : Union[bool, str, None] = None,
-    max_shard_size       : Union[int, str, None] = "5GB",
-    create_pr            : bool = False,
-    safe_serialization   : bool = True,
-    revision             : str = None,
-    commit_description   : str = "Upload model trained with Unsloth 2x faster",
-    tags                 : Optional[List[str]] = None,
-    temporary_location   : str = "_unsloth_temporary_saved_buffers",
-    maximum_memory_usage : float = 0.75,
+    repo_id: str,
+    tokenizer = None,
+    save_method: str = "merged_16bit",  # ["lora", "merged_16bit", "merged_4bit"]
+    use_temp_dir: Optional[bool] = None,
+    commit_message: Optional[str] = "Trained with Unsloth",
+    private: Optional[bool] = None,
+    token: Union[bool, str, None] = None,
+    max_shard_size: Union[int, str, None] = "5GB",
+    create_pr: bool = False,
+    safe_serialization: bool = True,
+    revision: str = None,
+    commit_description: str = "Upload model trained with Unsloth 2x faster",
+    tags: Optional[List[str]] = None,
+    temporary_location: str = "_unsloth_temporary_saved_buffers",
+    maximum_memory_usage: float = 0.75,
 ):
     """
-        Same as .push_to_hub(...) except 4bit weights are auto
-        converted to float16 with as few overhead as possible.
+    Same as .push_to_hub(...) except 4bit weights are auto
+    converted to float16 with as few overhead as possible.
 
-        Choose for `save_method` to be either:
-        1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
-        2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
-        3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+    Choose for `save_method` to be either:
+    1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+    2.  `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+    3.  `lora`: Save LoRA adapters with no merging. Useful for HF inference.
     """
     if tokenizer is None:
         logger.warning_once(
-            "Unsloth: You're not saving a tokenizer as well?\n"\
+            "Unsloth: You're not saving a tokenizer as well?\n"
             "You can do it separately via `tokenizer.push_to_hub(...)`"
         )
-    pass
 
     arguments = dict(locals())
-    arguments["model"]          = self
+    arguments["model"] = self
     arguments["save_directory"] = repo_id
-    arguments["push_to_hub"]    = True
+    arguments["push_to_hub"] = True
     del arguments["self"]
     del arguments["repo_id"]
     unsloth_generic_save(**arguments)
     for _ in range(3):
         gc.collect()
-pass
 
 
 def unsloth_save_pretrained_torchao(
     self,
-    save_directory       : Union[str, os.PathLike],
-    tokenizer            = None,
-    torchao_config       = None,
-    push_to_hub          : bool = False,
-    token                : Optional[Union[str, bool]] = None,
+    save_directory: Union[str, os.PathLike],
+    tokenizer = None,
+    torchao_config = None,
+    push_to_hub: bool = False,
+    token: Optional[Union[str, bool]] = None,
 ):
     """Quantizes the model with torchao and saves a torchao quantized checkpoint
 
@@ -2565,14 +2753,15 @@ def unsloth_save_pretrained_torchao(
     """
     # first merge the lora weights
     arguments = dict(locals())
-    arguments["model"]       = self
-    arguments["tokenizer"]   = tokenizer
-    arguments["push_to_hub"] = False # We save ourselves
-    arguments["save_method"] = "merged_16bit" # Must be 16bit
+    arguments["model"] = self
+    arguments["tokenizer"] = tokenizer
+    arguments["push_to_hub"] = False  # We save ourselves
+    arguments["save_method"] = "merged_16bit"  # Must be 16bit
     del arguments["self"]
     del arguments["torchao_config"]
 
-    if token is None and push_to_hub: token = get_token()
+    if token is None and push_to_hub:
+        token = get_token()
 
     if not isinstance(self, PeftModelForCausalLM) and not isinstance(self, PeftModel):
         self.save_pretrained(save_directory)
@@ -2584,9 +2773,13 @@ def unsloth_save_pretrained_torchao(
 
     from transformers import AutoModel, AutoTokenizer, TorchAoConfig
     from torchao import quantize_
+
     if torchao_config is None:
         from torchao.quantization import Int8DynamicActivationInt8WeightConfig
-        print("Unsloth: You did not specify a `torchao_config`, so defaulting to `Int8DynamicActivationInt8WeightConfig`")
+
+        print(
+            "Unsloth: You did not specify a `torchao_config`, so defaulting to `Int8DynamicActivationInt8WeightConfig`"
+        )
         torchao_config = Int8DynamicActivationInt8WeightConfig()
     quantization_config = TorchAoConfig(quant_type = torchao_config)
 
@@ -2594,9 +2787,9 @@ def unsloth_save_pretrained_torchao(
 
     # TorchAO must only use bfloat16 for loading (float16 fails)
     if HAS_TORCH_DTYPE:
-        kwargs = {"torch_dtype" : torch.bfloat16}
+        kwargs = {"torch_dtype": torch.bfloat16}
     else:
-        kwargs = {"dtype" : torch.bfloat16}
+        kwargs = {"dtype": torch.bfloat16}
     model = AutoModel.from_pretrained(
         arguments["save_directory"],
         device_map = "auto",
@@ -2610,21 +2803,25 @@ def unsloth_save_pretrained_torchao(
     safe_serialization = Version(importlib_version("torchao")) > Version("0.14.0")
     safe_serialization = False
     if push_to_hub:
-        if token is None and push_to_hub: token = get_token()
-        model.push_to_hub(torchao_save_directory, safe_serialization = safe_serialization, token = token)
+        if token is None and push_to_hub:
+            token = get_token()
+        model.push_to_hub(
+            torchao_save_directory, safe_serialization = safe_serialization, token = token
+        )
         tokenizer.push_to_hub(torchao_save_directory, token = token)
     else:
-        model.save_pretrained(torchao_save_directory, safe_serialization = safe_serialization)
+        model.save_pretrained(
+            torchao_save_directory, safe_serialization = safe_serialization
+        )
         tokenizer.save_pretrained(torchao_save_directory)
-    pass
     for _ in range(3):
         gc.collect()
-pass
 
 
 def not_implemented_save(*args, **kwargs):
-    raise NotImplementedError("Unsloth: Sorry GGUF is currently not supported for vision models!")
-pass
+    raise NotImplementedError(
+        "Unsloth: Sorry GGUF is currently not supported for vision models!"
+    )
 
 
 def patch_saving_functions(model, vision = False):
@@ -2637,7 +2834,6 @@ def patch_saving_functions(model, vision = False):
         original_push_to_hub = model.original_push_to_hub
     else:
         original_push_to_hub = model.push_to_hub
-    pass
 
     signature = str(inspect.signature(original_push_to_hub)).replace("NoneType", "None")
     signature = signature[1:]
@@ -2702,38 +2898,59 @@ def patch_saving_functions(model, vision = False):
 
     original_model = model
     while True:
-
         if original_model.push_to_hub.__name__ != "unsloth_push_to_hub":
             original_model.original_push_to_hub = original_model.push_to_hub
-            original_model.push_to_hub = types.MethodType(unsloth_push_to_hub, original_model)
+            original_model.push_to_hub = types.MethodType(
+                unsloth_push_to_hub, original_model
+            )
             if hasattr(original_model, "add_model_tags"):
-                original_model.add_model_tags(["unsloth",])
-            pass
-        pass
+                original_model.add_model_tags(
+                    [
+                        "unsloth",
+                    ]
+                )
 
-        if hasattr(original_model, "model"): original_model = original_model.model
-        else: break
-    pass
+        if hasattr(original_model, "model"):
+            original_model = original_model.model
+        else:
+            break
 
     # Add saving methods to top level model
     if not vision:
         if hasattr(model, "config"):
             # Counteract tokenizers
-            model.push_to_hub_merged      = types.MethodType(unsloth_generic_push_to_hub_merged,            model)
-            model.save_pretrained_merged  = types.MethodType(unsloth_generic_save_pretrained_merged,        model)
-            model.push_to_hub_gguf        = types.MethodType(unsloth_push_to_hub_gguf,                      model)
-            model.save_pretrained_gguf    = types.MethodType(unsloth_save_pretrained_gguf,                  model)
-            model.save_pretrained_torchao = types.MethodType(unsloth_save_pretrained_torchao,               model)
-            model.push_to_hub_ggml        = types.MethodType(unsloth_convert_lora_to_ggml_and_push_to_hub,  model)
-            model.save_pretrained_ggml    = types.MethodType(unsloth_convert_lora_to_ggml_and_save_locally, model)
-        pass
+            model.push_to_hub_merged = types.MethodType(
+                unsloth_generic_push_to_hub_merged, model
+            )
+            model.save_pretrained_merged = types.MethodType(
+                unsloth_generic_save_pretrained_merged, model
+            )
+            model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
+            model.save_pretrained_gguf = types.MethodType(
+                unsloth_save_pretrained_gguf, model
+            )
+            model.save_pretrained_torchao = types.MethodType(
+                unsloth_save_pretrained_torchao, model
+            )
+            model.push_to_hub_ggml = types.MethodType(
+                unsloth_convert_lora_to_ggml_and_push_to_hub, model
+            )
+            model.save_pretrained_ggml = types.MethodType(
+                unsloth_convert_lora_to_ggml_and_save_locally, model
+            )
     else:
         # Vision only 1 option
-        model.push_to_hub_merged      = types.MethodType(unsloth_generic_push_to_hub_merged,     model)
-        model.save_pretrained_merged  = types.MethodType(unsloth_generic_save_pretrained_merged, model)
-        model.push_to_hub_gguf        = types.MethodType(unsloth_push_to_hub_gguf,               model)
-        model.save_pretrained_gguf    = types.MethodType(unsloth_save_pretrained_gguf,           model)
-        model.save_pretrained_torchao = types.MethodType(unsloth_save_pretrained_torchao,        model)
-    pass
+        model.push_to_hub_merged = types.MethodType(
+            unsloth_generic_push_to_hub_merged, model
+        )
+        model.save_pretrained_merged = types.MethodType(
+            unsloth_generic_save_pretrained_merged, model
+        )
+        model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
+        model.save_pretrained_gguf = types.MethodType(
+            unsloth_save_pretrained_gguf, model
+        )
+        model.save_pretrained_torchao = types.MethodType(
+            unsloth_save_pretrained_torchao, model
+        )
     return model
-pass
diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py
index b35c1326c..bf8d45c89 100644
--- a/unsloth/tokenizer_utils.py
+++ b/unsloth/tokenizer_utils.py
@@ -44,10 +44,12 @@
 ]
 
 
-IGNORED_TOKENIZER_CHECKING = frozenset((
-    "CodeLlamaTokenizerFast",
-    "CodeLlamaTokenizer",
-))
+IGNORED_TOKENIZER_CHECKING = frozenset(
+    (
+        "CodeLlamaTokenizerFast",
+        "CodeLlamaTokenizer",
+    )
+)
 
 
 IGNORED_TOKENIZER_NAMES = [
@@ -56,26 +58,24 @@
     "unsloth/Qwen2.5-Coder-7B-Instruct",
 ]
 IGNORED_TOKENIZER_NAMES = frozenset(
-    [x.lower() for x in IGNORED_TOKENIZER_NAMES] + \
-    [x.lower()+"-bnb-4bit" for x in IGNORED_TOKENIZER_NAMES]
+    [x.lower() for x in IGNORED_TOKENIZER_NAMES]
+    + [x.lower() + "-bnb-4bit" for x in IGNORED_TOKENIZER_NAMES]
 )
 os.environ["UNSLOTH_IGNORED_TOKENIZER_NAMES"] = "\n".join(IGNORED_TOKENIZER_NAMES)
 
 # Check environments
 keynames = "\n" + "\n".join(os.environ.keys())
-IS_COLAB_ENVIRONMENT  = "\nCOLAB_"  in keynames
+IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
 IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames
 KAGGLE_TMP = "/tmp"
 del keynames
 
 
 def try_fix_tokenizer(tokenizer, prepend = True):
-
     if hasattr(tokenizer, "_tokenizer"):
         converted_tokenizer = tokenizer._tokenizer
     else:
         converted_tokenizer = convert_slow_tokenizer(tokenizer)
-    pass
 
     tokenizer_string = converted_tokenizer.to_str()
 
@@ -83,7 +83,6 @@ def try_fix_tokenizer(tokenizer, prepend = True):
     prepend_text = '{"type":"Prepend","prepend":"ā"},'
     if not prepend and prepend_text in tokenizer_string:
         tokenizer_string = tokenizer_string.replace(prepend_text, "", 1)
-    pass
 
     dir_names = dir(tokenizer)
     # Get eos_token, bos_token etc
@@ -91,19 +90,21 @@ def try_fix_tokenizer(tokenizer, prepend = True):
 
     for token_name in token_names:
         token = getattr(tokenizer, token_name, None)
-        if token is None: continue
+        if token is None:
+            continue
         token_id = getattr(tokenizer, token_name + "_id", None)
 
         # Locate the token's id mapping in the string
         find_text = f'"id":{token_id},"content":"'
         start = tokenizer_string.find(find_text) + len(find_text)
-        if start == -1: continue
-        end   = tokenizer_string.find('",', start)
+        if start == -1:
+            continue
+        end = tokenizer_string.find('",', start)
 
-        bad_token = tokenizer_string[start : end]
+        bad_token = tokenizer_string[start:end]
         # Check if token is the actual same one - if not, edit it
         if bad_token != token:
-            bad_text  = f'{find_text}{bad_token}",'
+            bad_text = f'{find_text}{bad_token}",'
             good_text = f'{find_text}{token}",'
             tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
 
@@ -111,24 +112,20 @@ def try_fix_tokenizer(tokenizer, prepend = True):
             bad_text = f'"{bad_token}":{token_id},'
             good_text = f'"{token}":{token_id},'
             tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
-        pass
-    pass
 
     fixed_tokenizer = converted_tokenizer.from_str(tokenizer_string)
     return fixed_tokenizer
-pass
 
 
 def get_sorted_dict(dictionary):
     sorted_keys = sorted(dictionary.values())
-    inverted_dictionary = { value : key for key, value in dictionary.items() }
+    inverted_dictionary = {value: key for key, value in dictionary.items()}
 
     sorted_dictionary = {}
     for key in sorted_keys:
         value = inverted_dictionary[key]
         sorted_dictionary[value] = key
     return sorted_dictionary
-pass
 
 
 def convert_to_fast_tokenizer(
@@ -136,13 +133,14 @@ def convert_to_fast_tokenizer(
     temporary_location = "_unsloth_sentencepiece_temp",
 ):
     is_fast = getattr(slow_tokenizer, "is_fast", False)
-    if is_fast: return slow_tokenizer
-    
+    if is_fast:
+        return slow_tokenizer
+
     try:
         tokenizer_name = slow_tokenizer.__class__.__name__
         lowered_tokenizer_name = tokenizer_name.lower()
         if lowered_tokenizer_name.endswith("tokenizer"):
-            class_name = lowered_tokenizer_name[:-len("tokenizer")]
+            class_name = lowered_tokenizer_name[: -len("tokenizer")]
             FastTokenizer = eval(
                 f'__import__(f"transformers.models.{class_name}").{tokenizer_name}Fast'
             )
@@ -150,52 +148,52 @@ def convert_to_fast_tokenizer(
             FastTokenizer = PreTrainedTokenizerFast
     except:
         FastTokenizer = PreTrainedTokenizerFast
-    pass
 
     # Get all arguments (bos_token, etc)
     docs = FastTokenizer.__doc__
-    docs = docs[docs.find("Args:"):]
+    docs = docs[docs.find("Args:") :]
     args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
     args = [x for x in args if not x.endswith("_file")]
 
     # Also some missing maybe!
     docs = PreTrainedTokenizerFast.__doc__
-    docs = docs[docs.find("Args:"):]
+    docs = docs[docs.find("Args:") :]
     args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
     args2 = [x for x in args2 if not x.endswith("_file")]
     args = list(set(args + args2))
 
     kwargs = {}
-    for arg in args: kwargs[arg] = getattr(slow_tokenizer, arg, None)
+    for arg in args:
+        kwargs[arg] = getattr(slow_tokenizer, arg, None)
     kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = True)
-    fast_tokenizer = FastTokenizer( **kwargs )
+    fast_tokenizer = FastTokenizer(**kwargs)
 
     # Check if they're similar!
     sorted_slow_tokenizer = get_sorted_dict(slow_tokenizer.get_vocab())
     sorted_fast_tokenizer = get_sorted_dict(fast_tokenizer.get_vocab())
 
-    check_vocab   = (sorted_slow_tokenizer == sorted_fast_tokenizer)
-    check_special = (slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens)
+    check_vocab = sorted_slow_tokenizer == sorted_fast_tokenizer
+    check_special = (
+        slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens
+    )
 
     # Failure so return slow_tokenizer
-    if not check_vocab or not check_special: return slow_tokenizer
+    if not check_vocab or not check_special:
+        return slow_tokenizer
 
     # Now confirm if they match
     if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
         # Maybe remove prepending of __apple?
         kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = False)
-        fast_tokenizer = FastTokenizer( **kwargs )
+        fast_tokenizer = FastTokenizer(**kwargs)
         if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
             # Failure :(
             return slow_tokenizer
-        pass
-    pass
 
     # Also tokenizer.model is missing!
     name = slow_tokenizer.name_or_path.replace("/", "_")
     if not os.path.exists(temporary_location):
         os.makedirs(temporary_location)
-    pass
     new_location = f"{temporary_location}/{name}"
     slow_tokenizer.save_pretrained(new_location)
     fast_tokenizer.save_pretrained(new_location)
@@ -205,66 +203,72 @@ def convert_to_fast_tokenizer(
     if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
         return fast_tokenizer
     return slow_tokenizer
-pass
 
 
 # Check Mistral chat template without BOS / EOS
-mistral_template = \
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{% if messages[1]['role'] == 'user' %}"\
-            "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[2:] %}"\
-        "{% else %}"\
-            "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[1:] %}"\
-        "{% endif %}"\
-    "{% else %}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ message['content'] }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
+mistral_template = (
+    "{% if messages[0]['role'] == 'system' %}"
+    "{% if messages[1]['role'] == 'user' %}"
+    "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[2:] %}"
+    "{% else %}"
+    "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% endif %}"
+    "{% else %}"
+    "{% set loop_messages = messages %}"
+    "{% endif %}"
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ message['content'] }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
     "{% endfor %}"
-pass
+)
 
 # Check Llama chat template without BOS / EOS
-llama_template = \
-    "{% if messages[0]['role'] == 'system' %}"\
-        "{% if messages[1]['role'] == 'user' %}"\
-            "{{ '[INST] <>\n' + messages[0]['content'] + '\n<>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[2:] %}"\
-        "{% else %}"\
-            "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
-            "{% set loop_messages = messages[1:] %}"\
-        "{% endif %}"\
-    "{% else %}"\
-        "{% set loop_messages = messages %}"\
-    "{% endif %}"\
-    "{% for message in loop_messages %}"\
-        "{% if message['role'] == 'user' %}"\
-            "{{ '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
-        "{% elif message['role'] == 'assistant' %}"\
-            "{{ ' ' + message['content'].strip() + ' ' }}"\
-        "{% else %}"\
-            "{{ raise_exception('Only user and assistant roles are supported!') }}"\
-        "{% endif %}"\
+llama_template = (
+    "{% if messages[0]['role'] == 'system' %}"
+    "{% if messages[1]['role'] == 'user' %}"
+    "{{ '[INST] <>\n' + messages[0]['content'] + '\n<>\n\n' + messages[1]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[2:] %}"
+    "{% else %}"
+    "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
+    "{% set loop_messages = messages[1:] %}"
+    "{% endif %}"
+    "{% else %}"
+    "{% set loop_messages = messages %}"
+    "{% endif %}"
+    "{% for message in loop_messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '[INST] ' + message['content'].strip() + ' [/INST]' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ ' ' + message['content'].strip() + ' ' }}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
     "{% endfor %}"
-pass
+)
 
 
 def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
     # Get eos_token, bos_token etc
-    if not hasattr(slow_tokenizer, "all_special_tokens"): return True
+    if not hasattr(slow_tokenizer, "all_special_tokens"):
+        return True
     dir_names = dir(slow_tokenizer)
-    special_tokens = list(filter(None, (
-        getattr(slow_tokenizer, x) for x in dir_names
-        if x.endswith("_token") and x.count("_") == 1
-    )))
+    special_tokens = list(
+        filter(
+            None,
+            (
+                getattr(slow_tokenizer, x)
+                for x in dir_names
+                if x.endswith("_token") and x.count("_") == 1
+            ),
+        )
+    )
     all_special_tokens = list(set(special_tokens + slow_tokenizer.all_special_tokens))
 
     # Remove replacement char for false positive
@@ -275,7 +279,7 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
     check_chat_template1 = True
     check_chat_template2 = True
     check_chat_template3 = True
-    
+
     """
     Weirdly Mistral tokenizers are actually correct??
     Ie below will actually load mistral v1 and v3 incorrectly!
@@ -313,16 +317,20 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
     slow_tokenizer.chat_template = slow_chat_template
     fast_tokenizer.chat_template = fast_chat_template
     """
-    check_chat_template = check_chat_template1 and check_chat_template2 and check_chat_template3
+    check_chat_template = (
+        check_chat_template1 and check_chat_template2 and check_chat_template3
+    )
 
     # Try special tokens
     try:
-        string = "\n".join(all_special_tokens) + \
-            "A quick brown fox jumps over the lazy dog!!\n\nHi\n\n" + \
-            "".join(all_special_tokens)
-        check_special_tokens = \
-            slow_tokenizer(string).input_ids == \
-            fast_tokenizer(string).input_ids
+        string = (
+            "\n".join(all_special_tokens)
+            + "A quick brown fox jumps over the lazy dog!!\n\nHi\n\n"
+            + "".join(all_special_tokens)
+        )
+        check_special_tokens = (
+            slow_tokenizer(string).input_ids == fast_tokenizer(string).input_ids
+        )
 
         return check_chat_template and check_special_tokens
     except:
@@ -333,8 +341,6 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
             return check_chat_template
         else:
             return False
-    pass
-pass
 
 
 def fix_sentencepiece_tokenizer(
@@ -347,36 +353,37 @@ def fix_sentencepiece_tokenizer(
     # We need to manually edit the sentencepiece tokenizer!
     try:
         from transformers.convert_slow_tokenizer import import_protobuf
+
         sentencepiece_model_pb2 = import_protobuf()
     except Exception as e:
         try:
             import google.protobuf
             from unsloth_zoo.utils import Version
+
             protobuf_version = Version(google.protobuf.__version__)
             if protobuf_version > Version("3.20.3"):
                 raise RuntimeError(
-                    f"Unsloth: Your protobuf version = {protobuf_version} is too new.\n"\
+                    f"Unsloth: Your protobuf version = {protobuf_version} is too new.\n"
                     f"Please downgrade via `pip install --force-reinstall protobuf==3.20.3`"
                 )
         except:
             # This will only work for older SentencePiece versions <= 3.20.3
             from transformers.utils import sentencepiece_model_pb2
-    pass
 
     if not os.path.exists(temporary_location):
         os.makedirs(temporary_location)
-    pass
 
     # Check if tokenizer.model exists
     if not os.path.isfile(f"{temporary_location}/tokenizer.model"):
         return new_tokenizer
-    pass
 
     # First save the old tokenizer
     old_tokenizer.save_pretrained(temporary_location)
 
     tokenizer_file = sentencepiece_model_pb2.ModelProto()
-    tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())
+    tokenizer_file.ParseFromString(
+        open(f"{temporary_location}/tokenizer.model", "rb").read()
+    )
 
     # Now save the new tokenizer
     new_tokenizer.save_pretrained(temporary_location)
@@ -385,48 +392,47 @@ def fix_sentencepiece_tokenizer(
     for old_token, new_token in token_mapping.items():
         ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
         ids = ids[0]
-        if (len(ids) != 1):
+        if len(ids) != 1:
             # Skip this token!
-            print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
+            print(
+                f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!"
+            )
             continue
-        pass
         ids = ids[0]
         # [TODO] Hack for Starling - try except
         try:
             tokenizer_piece = tokenizer_file.pieces[ids]
         except:
             continue
-        assert(tokenizer_piece.piece == old_token)
+        assert tokenizer_piece.piece == old_token
         tokenizer_piece.piece = new_token
-    pass
 
     # And now write it
     with open(f"{temporary_location}/tokenizer.model", "wb") as file:
         file.write(tokenizer_file.SerializeToString())
-    pass
 
     # And load it!
     from transformers import AutoTokenizer
+
     tokenizer = AutoTokenizer.from_pretrained(
         temporary_location,
         eos_token = new_tokenizer.eos_token,
         pad_token = new_tokenizer.pad_token,
     )
     return tokenizer
-pass
 
 
 def fix_sentencepiece_gguf(saved_location):
     """
-        Fixes sentencepiece tokenizers which did not extend the vocabulary with
-        user defined tokens.
-        Inspiration from https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py
+    Fixes sentencepiece tokenizers which did not extend the vocabulary with
+    user defined tokens.
+    Inspiration from https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py
     """
     from copy import deepcopy
     from transformers.utils import sentencepiece_model_pb2
     import json
     from enum import IntEnum
-    
+
     class SentencePieceTokenTypes(IntEnum):
         NORMAL = 1
         UNKNOWN = 2
@@ -434,54 +440,58 @@ class SentencePieceTokenTypes(IntEnum):
         USER_DEFINED = 4
         UNUSED = 5
         BYTE = 6
-    pass
 
     # Load tokenizer.model
     tokenizer_file = sentencepiece_model_pb2.ModelProto()
-    if not os.path.isfile(f"{saved_location}/tokenizer.model"): return
-    tokenizer_file.ParseFromString(open(f"{saved_location}/tokenizer.model", "rb").read())
+    if not os.path.isfile(f"{saved_location}/tokenizer.model"):
+        return
+    tokenizer_file.ParseFromString(
+        open(f"{saved_location}/tokenizer.model", "rb").read()
+    )
     sentence_piece_size = len(tokenizer_file.pieces)
 
     # Load added_tokens_json
-    if not os.path.isfile(f"{saved_location}/added_tokens.json"): return
+    if not os.path.isfile(f"{saved_location}/added_tokens.json"):
+        return
     with open(f"{saved_location}/added_tokens.json", "r", encoding = "utf-8") as file:
         added_tokens_json = json.load(file)
-    pass
-    if len(added_tokens_json) == 0: return
+    if len(added_tokens_json) == 0:
+        return
 
-    added_tokens_json = dict(sorted(added_tokens_json.items(), key = lambda item: item[1]))
+    added_tokens_json = dict(
+        sorted(added_tokens_json.items(), key = lambda item: item[1])
+    )
     new_size = sentence_piece_size + len(added_tokens_json)
 
     # Confirm added_tokens_json is correct
     added_tokens_ids = np.array(list(added_tokens_json.values()))
     diff = np.diff(added_tokens_ids)
-    if (diff.min() != 1 or diff.max() != 1): return
-    if (added_tokens_ids.min() != sentence_piece_size): return
+    if diff.min() != 1 or diff.max() != 1:
+        return
+    if added_tokens_ids.min() != sentence_piece_size:
+        return
 
     # Edit sentence piece tokens with added_tokens_json
     logger.warning(
-        f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"\
-        f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"\
+        f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"
+        f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"
         f"But we need to extend to sentencepiece vocab size ({new_size})."
     )
-    new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids):])
+    new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids) :])
     for new_token, added_token in zip(new_tokens, added_tokens_json.keys()):
         new_token.piece = added_token.encode("utf-8")
         new_token.score = -1000.0
-        new_token.type  = SentencePieceTokenTypes.USER_DEFINED
-    pass
+        new_token.type = SentencePieceTokenTypes.USER_DEFINED
 
     tokenizer_file.pieces.extend(new_tokens)
 
     with open(f"{saved_location}/tokenizer.model", "wb") as file:
         file.write(tokenizer_file.SerializeToString())
-    pass
 
     # Add padding tokens
     # actual_vocab_size = model.config.vocab_size
     # padding = actual_vocab_size - len(tokenizer_file.pieces)
     return
-pass
 
 
 def _load_correct_tokenizer(
@@ -501,7 +511,6 @@ def _load_correct_tokenizer(
         cache_dir = os.path.join(KAGGLE_TMP, cache_dir)
     else:
         cache_dir = None
-    pass
 
     # Try loading the slow tokenizer. If it fails, then try Fast only
     # Mainly to solve Deepseek models with no tokenizer.model file
@@ -509,15 +518,15 @@ def _load_correct_tokenizer(
     try:
         slow_tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
-            model_max_length  = model_max_length,
-            padding_side      = padding_side,
-            token             = token,
+            model_max_length = model_max_length,
+            padding_side = padding_side,
+            token = token,
             trust_remote_code = trust_remote_code,
             # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
-            use_fast          = False,
-            legacy            = False,
-            from_slow         = True,
-            cache_dir         = cache_dir,
+            use_fast = False,
+            legacy = False,
+            from_slow = True,
+            cache_dir = cache_dir,
         )
     except:
         slow_tokenizer = None
@@ -525,17 +534,17 @@ def _load_correct_tokenizer(
         #     f"Unsloth: {tokenizer_name} has no tokenizer.model file.\n"\
         #     "Just informing you about this - this is not a critical error."
         # )
-    pass
     # Unsure why this occurs!
-    if type(slow_tokenizer) is bool: slow_tokenizer = None
+    if type(slow_tokenizer) is bool:
+        slow_tokenizer = None
 
     fast_tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_name,
-        model_max_length  = model_max_length,
-        padding_side      = padding_side,
-        token             = token,
+        model_max_length = model_max_length,
+        padding_side = padding_side,
+        token = token,
         trust_remote_code = trust_remote_code,
-        cache_dir         = cache_dir,
+        cache_dir = cache_dir,
     )
 
     if not fix_tokenizer or tokenizer_name in IGNORED_TOKENIZER_NAMES:
@@ -547,22 +556,25 @@ def _load_correct_tokenizer(
     elif "phi-4" in tokenizer_name.lower():
         return fast_tokenizer
     elif slow_tokenizer is not None:
-        if hasattr(fast_tokenizer, "add_bos_token") and hasattr(slow_tokenizer, "add_bos_token"):
+        if hasattr(fast_tokenizer, "add_bos_token") and hasattr(
+            slow_tokenizer, "add_bos_token"
+        ):
             fast_tokenizer.add_bos_token = slow_tokenizer.add_bos_token
-        if hasattr(fast_tokenizer, "add_eos_token") and hasattr(slow_tokenizer, "add_eos_token"):
+        if hasattr(fast_tokenizer, "add_eos_token") and hasattr(
+            slow_tokenizer, "add_eos_token"
+        ):
             fast_tokenizer.add_eos_token = slow_tokenizer.add_eos_token
-        
+
         # Confirm if slow and fast are equivalent!
         if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
             return fast_tokenizer
         else:
-            logger.warning(f"Unsloth: Will load {tokenizer_name} as a legacy tokenizer.")
+            logger.warning(
+                f"Unsloth: Will load {tokenizer_name} as a legacy tokenizer."
+            )
             return convert_to_fast_tokenizer(slow_tokenizer)
-        pass
     else:
         return fast_tokenizer
-    pass
-pass
 
 
 def load_correct_tokenizer(
@@ -592,10 +604,13 @@ def load_correct_tokenizer(
         chat_template = old_chat_template
 
     # Also check Llama-2 old style models
-    elif old_chat_template is not None and \
-        "[/INST]" in old_chat_template and "[INST]" in old_chat_template and \
-        "bos_token" in old_chat_template and "eos_token" in old_chat_template:
-
+    elif (
+        old_chat_template is not None
+        and "[/INST]" in old_chat_template
+        and "[INST]" in old_chat_template
+        and "bos_token" in old_chat_template
+        and "eos_token" in old_chat_template
+    ):
         chat_template = old_chat_template
 
     else:
@@ -604,12 +619,9 @@ def load_correct_tokenizer(
             raise RuntimeError(
                 "Unsloth: Fixing chat template failed - please file a report immediately!"
             )
-        pass
-    pass
 
     tokenizer.chat_template = chat_template
     return tokenizer
-pass
 
 
 def _find_end_position(template, endfor, endif):
@@ -621,8 +633,6 @@ def _find_end_position(template, endfor, endif):
         return endfor
     else:
         return endif
-    pass
-pass
 
 
 def _fix_chat_template(chat_template):
@@ -635,28 +645,33 @@ def _fix_chat_template(chat_template):
         chosen_end = _find_end_position(chat_template, endfor, endif)
     if chosen_end is None:
         return chat_template
-    
+
     where = chat_template.find(chosen_end)
 
-    after_endfor = chat_template[where + len(chosen_end):]
+    after_endfor = chat_template[where + len(chosen_end) :]
 
     dash = "-" if chosen_end.startswith("{%-") else ""
 
-    if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \
-        after_endfor.startswith("{{") and after_endfor.endswith("}}") and \
-        after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1:
-
-        after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif
+    if (
+        "{%" + dash + " if" not in after_endfor
+        and "{%" + dash + " set " not in after_endfor
+        and after_endfor.startswith("{{")
+        and after_endfor.endswith("}}")
+        and after_endfor.count("{{") == 1
+        and after_endfor.count("}}") == 1
+    ):
+        after_endfor = (
+            "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif
+        )
 
-        chat_template = chat_template[:where + len(chosen_end)] + after_endfor
-    pass
+        chat_template = chat_template[: where + len(chosen_end)] + after_endfor
     return chat_template
-pass
 
 
 def fix_chat_template(tokenizer):
     chat_template = getattr(tokenizer, "chat_template", None)
-    if chat_template is None: return None
+    if chat_template is None:
+        return None
 
     ### 1. Check if add_generation_prompt works
     # Check for ShareGPT style first
@@ -665,62 +680,69 @@ def fix_chat_template(tokenizer):
         messages = [
             {"role": "user", "content": "Who are you?"},
         ]
-        tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
+        tokenizer.apply_chat_template(
+            messages, add_generation_prompt = False, tokenize = False
+        )
         is_sharegpt = False
     except:
         try:
             messages = [
                 {"from": "human", "value": "Who are you?"},
             ]
-            tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
+            tokenizer.apply_chat_template(
+                messages, add_generation_prompt = False, tokenize = False
+            )
             is_sharegpt = True
         except:
             is_sharegpt = None
-        pass
-    pass
 
     # Not ShareGPT or HF style - just return
-    if is_sharegpt is None: return chat_template
+    if is_sharegpt is None:
+        return chat_template
 
     # Tokenize
     messages = [
-        {"role": "user", "content": "Who are you?"} \
-        if not is_sharegpt else \
-        {"from": "human", "value": "Who are you?"}
+        {"role": "user", "content": "Who are you?"}
+        if not is_sharegpt
+        else {"from": "human", "value": "Who are you?"}
     ]
-    no  = tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
-    yes = tokenizer.apply_chat_template(messages, add_generation_prompt =  True, tokenize = False)
+    no = tokenizer.apply_chat_template(
+        messages, add_generation_prompt = False, tokenize = False
+    )
+    yes = tokenizer.apply_chat_template(
+        messages, add_generation_prompt = True, tokenize = False
+    )
 
     if no == yes:
         # SAME?! That's not good! We check for add_generation_prompt
-        if   "{% if add_generation_prompt %}" not in chat_template and \
-            "{%- if add_generation_prompt %}" not in chat_template:
+        if (
+            "{% if add_generation_prompt %}" not in chat_template
+            and "{%- if add_generation_prompt %}" not in chat_template
+        ):
             # Try fixing it by adding it
             new_chat_template = _fix_chat_template(chat_template)
-            if   "{% if add_generation_prompt %}" not in new_chat_template and \
-                "{%- if add_generation_prompt %}" not in new_chat_template:
+            if (
+                "{% if add_generation_prompt %}" not in new_chat_template
+                and "{%- if add_generation_prompt %}" not in new_chat_template
+            ):
                 raise RuntimeError(
-                    f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
-                    "does not have a {% if add_generation_prompt %} for generation purposes.\n"\
+                    f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"
+                    "does not have a {% if add_generation_prompt %} for generation purposes.\n"
                     f"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!"
                 )
             else:
                 logger.warning_once(
-                    "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"\
+                    "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"
                     f"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!"
                 )
                 chat_template = new_chat_template
-            pass
         else:
             raise RuntimeError(
-                f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
-                "has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n"\
+                f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"
+                "has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n"
                 "Please file a bug report immediately - thanks!"
             )
-        pass
-    pass
     return chat_template
-pass
 
 
 def check_tokenizer(
@@ -741,59 +763,64 @@ def check_tokenizer(
     # We ignore some of them!
     if tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
         return tokenizer
-    pass
 
     max_embedding_size = model.model.embed_tokens.weight.shape[0]
     added_tokens_fast = tokenizer.added_tokens_decoder
-    added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
+    added_tokens_fast = {
+        index: str(value) for index, value in added_tokens_fast.items()
+    }
     sorted_keys = sorted(added_tokens_fast)
-    added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}
+    added_tokens_fast = {key: added_tokens_fast[key] for key in sorted_keys}
 
     for j, index in enumerate(added_tokens_fast.keys()):
         if index >= max_embedding_size:
-            bad_indices = list(added_tokens_fast.keys  ())[j:]
-            bad_tokens  = list(added_tokens_fast.values())[j:]
+            bad_indices = list(added_tokens_fast.keys())[j:]
+            bad_tokens = list(added_tokens_fast.values())[j:]
             if not _reload:
                 # Try removing the token
                 added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
                 special_tokens = tokenizer.special_tokens_map
                 import itertools
+
                 special_tokens = frozenset(
                     itertools.chain.from_iterable(
                         [x] if type(x) is str else x for x in special_tokens.values()
                     )
                 )
                 can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]
-                can_be_removed2 = [x for x in can_be_removed1 if x in tokenizer._added_tokens_encoder.keys()]
+                can_be_removed2 = [
+                    x
+                    for x in can_be_removed1
+                    if x in tokenizer._added_tokens_encoder.keys()
+                ]
 
                 # Check of extra tokens can in fact we removed!
-                can_be_removed = \
-                    (len(can_be_removed1) == len(bad_tokens)) and \
-                    (len(can_be_removed2) == len(bad_tokens))
+                can_be_removed = (len(can_be_removed1) == len(bad_tokens)) and (
+                    len(can_be_removed2) == len(bad_tokens)
+                )
 
                 # Check if sep_token or other generic types
                 remove_generic = False
                 try_mapper = []
                 if not can_be_removed:
                     names = dir(tokenizer)
-                    names = (x for x in names if x.endswith("_token") and x.count("_") == 1)
+                    names = (
+                        x for x in names if x.endswith("_token") and x.count("_") == 1
+                    )
                     generic_tokens = [(x, getattr(tokenizer, x, None)) for x in names]
 
                     try_removal = []
                     for token in bad_tokens:
-                        for (name_token, check_token) in generic_tokens:
+                        for name_token, check_token in generic_tokens:
                             if check_token == token:
                                 try_removal.append(token)
                                 try_mapper.append(name_token)
-                            pass
-                        pass
-                    pass
 
                     # Recheck!
-                    can_be_removed = (len(try_removal) == len(bad_tokens))
-                    if can_be_removed: remove_generic = True
+                    can_be_removed = len(try_removal) == len(bad_tokens)
+                    if can_be_removed:
+                        remove_generic = True
                     can_be_removed1 = bad_tokens
-                pass
 
                 if can_be_removed:
                     # Yes it can be fixed!
@@ -806,32 +833,26 @@ def check_tokenizer(
                             # Remove sep token for example
                             setattr(tokenizer, try_mapper[j], None)
                             setattr(tokenizer, try_mapper[j] + "_id", None)
-                        pass
-                    pass
                     # Confirm 1 more time!
                     if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:
                         logger.warning_once(
-                            f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"\
-                            f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
+                            f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"
+                            f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"
                             "We removed these bad tokens. If you think this is incorrect, fix your tokenizer first."
                         )
                         return convert_to_fast_tokenizer(tokenizer)
-                    pass
-                pass
 
                 # :( Failure
                 raise RuntimeError(
-                    f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
-                    f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
+                    f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"
+                    f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"
                     f"Fix your tokenizer since it'll perform out of bounds memory accesses."
                 )
-            pass
-            
+
             if IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT:
                 cache_dir = "huggingface_tokenizers_cache"
             else:
                 cache_dir = None
-            pass
 
             # Sometimes slow tokenizer does not work like Deepseek
             try:
@@ -861,16 +882,12 @@ def check_tokenizer(
                 # Tokenizer has out of bounds issues and we can't
                 # load the slow tokenizer version :(
                 logger.warning_once(
-                    "Unsloth: Tokenizer is most likely buggy, and Unsloth failed to repair it.\n"\
-                    "It will still work, but beware of out of bounds memory accesses.\n"\
+                    "Unsloth: Tokenizer is most likely buggy, and Unsloth failed to repair it.\n"
+                    "It will still work, but beware of out of bounds memory accesses.\n"
                     "Please file an issue on the model owner's repo about this issue."
                 )
                 return tokenizer
-            pass
-        pass
-    pass
     return convert_to_fast_tokenizer(tokenizer)
-pass
 
 
 import inspect
@@ -879,9 +896,11 @@ def check_tokenizer(
 import trl.trainer.sft_trainer
 from trl.trainer.sft_trainer import *
 from transformers.trainer import *
+
 try:
     from trl.trainer.sft_trainer import neftune_post_forward_hook
 except:
+
     def neftune_post_forward_hook(module, input, output):
         """
         Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
@@ -909,13 +928,11 @@ def neftune_post_forward_hook(module, input, output):
             mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
             output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
         return output
-    pass
-pass
 
 
 def patch_sft_trainer_tokenizer():
     """
-        Patches the trainer with changes
+    Patches the trainer with changes
     """
     try:
         sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer")
@@ -923,39 +940,50 @@ def patch_sft_trainer_tokenizer():
         return
     all_imports = dir(trl.trainer.sft_trainer)
 
-    for (function_name, replacer,) in (
+    for (
+        function_name,
+        replacer,
+    ) in (
         # ("_prepare_non_packed_dataloader", "def tokenize(element):",),
-        ("_prepare_non_packed_dataloader", None,),
-        ("_prepare_dataset", None,),
+        (
+            "_prepare_non_packed_dataloader",
+            None,
+        ),
+        (
+            "_prepare_dataset",
+            None,
+        ),
         # ("_prepare_packed_dataloader", "if dataset_text_field is not None",),
     ):
-        if not hasattr(sft_trainer, function_name): continue
+        if not hasattr(sft_trainer, function_name):
+            continue
 
         function = getsource(eval(f"sft_trainer.{function_name}"))
         where = function.find("def")
         function = function.split("\n")
         function = "\n".join(x[where:] for x in function)
 
-        check_text = \
-        "\n"\
-        "if 'tokenizer'          not in locals(): tokenizer = processing_class\n"\
-        "if 'formatting_func'    not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\
-        "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\
-        "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\
-        "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\
-        "chat_template = getattr(tokenizer, 'chat_template', None)\n"\
-        "chat_template = '' if chat_template is None else chat_template\n"\
-        "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
-        "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\
-        "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\
-        "    from functools import partial\n"\
-        "    tokenizer = partial(tokenizer, add_special_tokens = False)\n"\
-        "    processing_class = tokenizer\n"\
-        "else:\n"\
-        "    add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n"
+        check_text = (
+            "\n"
+            "if 'tokenizer'          not in locals(): tokenizer = processing_class\n"
+            "if 'formatting_func'    not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"
+            "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"
+            "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"
+            "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"
+            "chat_template = getattr(tokenizer, 'chat_template', None)\n"
+            "chat_template = '' if chat_template is None else chat_template\n"
+            "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "
+            "if getattr(tokenizer, 'bos_token', None) is not None else False\n"
+            "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"
+            "    from functools import partial\n"
+            "    tokenizer = partial(tokenizer, add_special_tokens = False)\n"
+            "    processing_class = tokenizer\n"
+            "else:\n"
+            "    add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n"
+        )
 
         check_text = check_text.split("\n")
-        check_text = "\n".join(" "*where + x for x in check_text)
+        check_text = "\n".join(" " * where + x for x in check_text)
         check_text = check_text.rstrip() + "\n"
 
         if replacer is None:
@@ -965,101 +993,111 @@ def patch_sft_trainer_tokenizer():
                 function,
                 flags = re.MULTILINE | re.DOTALL,
             )
-            if len(replacer) == 0: continue
+            if len(replacer) == 0:
+                continue
             replacer = replacer[0]
             function = function.replace(replacer, replacer + check_text)
         else:
             function = function.replace(replacer, check_text + replacer)
-        pass
 
         x = [x for x in all_imports if x in function]
         exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals())
         exec(function, locals(), globals())
-        exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals())
-    pass
+        exec(
+            f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}",
+            globals(),
+        )
 
     # Patch train with fix_untrained_tokens
-    for path_to_trainer in \
-        ("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer", "kto_trainer.KTOTrainer"):
-
+    for path_to_trainer in (
+        "sft_trainer.SFTTrainer",
+        "dpo_trainer.DPOTrainer",
+        "kto_trainer.KTOTrainer",
+    ):
         function_name, replacer = "train", "if resume_from_checkpoint is False:"
         function = getsource(eval(f"trl.trainer.{path_to_trainer}.{function_name}"))
         where = function.find("def")
         function = function.split("\n")
         function = "\n".join(x[where:] for x in function)
 
-        check_text = \
-        "\n"\
-        "import subprocess, re, gc, numpy as np\n"\
-        "a = np.array([0,])\n"\
-        "try:\n"\
-        "    a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\
-        "    a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)\n"\
-        "    a = np.array([int(x.decode('utf-8'))/1024 for x in a])\n"\
-        "except:\n"\
-        "    if not torch.cuda.is_available():\n"\
-        "        raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"\
-        "if ((a - PRE_CHECK) >= 1).sum() > 1:\n"\
-        "    raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"\
-        "for _ in range(3):\n"\
-        "    gc.collect()\n"\
-        "    torch.cuda.empty_cache()\n"\
-        "pass\n"\
-        "\n"\
-        "tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\n"\
-        "fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
-        "fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\n\n"
+        check_text = (
+            "\n"
+            "import subprocess, re, gc, numpy as np\n"
+            "a = np.array([0,])\n"
+            "try:\n"
+            "    a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"
+            "    a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)\n"
+            "    a = np.array([int(x.decode('utf-8'))/1024 for x in a])\n"
+            "except:\n"
+            "    if not torch.cuda.is_available():\n"
+            "        raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"
+            "if ((a - PRE_CHECK) >= 1).sum() > 1:\n"
+            "    raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"
+            "for _ in range(3):\n"
+            "    gc.collect()\n"
+            "    torch.cuda.empty_cache()\n"
+            "pass\n"
+            "\n"
+            "tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\n"
+            "fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"
+            "fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\n\n"
+        )
 
         # Warn on gradient accumulation steps if it's used
-        check_text += \
-        "\n"\
-        "try:\n"\
-        "    gradient_accumulation_steps = self.args.gradient_accumulation_steps\n"\
-        "    if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\n"\
-        "        from transformers import __version__ as transformers_version\n"\
-        "        from packaging.version import Version\n"\
-        "        if Version(transformers_version) <= Version('4.45.2'):\n"\
-        "            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\
-        "                  '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\n"\
-        "except:\n"\
-        "    pass\n"\
-        "\n\n"
+        check_text += (
+            "\n"
+            "try:\n"
+            "    gradient_accumulation_steps = self.args.gradient_accumulation_steps\n"
+            "    if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\n"
+            "        from transformers import __version__ as transformers_version\n"
+            "        from packaging.version import Version\n"
+            "        if Version(transformers_version) <= Version('4.45.2'):\n"
+            "            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"
+            "                  '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\n"
+            "except:\n"
+            "    pass\n"
+            "\n\n"
+        )
 
         # Add NEFTune since it doesn't seem to work?? We need to manually inject it
-        check_text += \
-        "\n"\
-        "if hasattr(self, 'neftune_hook_handle'):\n"\
-        "    self.neftune_hook_handle.remove()\n"\
-        "    if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\
-        "\n"\
-        "if getattr(self, 'neftune_noise_alpha', None) is not None:\n"\
-        "    self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\
-        "    self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\
-        "pass\n"\
-        "\n"
+        check_text += (
+            "\n"
+            "if hasattr(self, 'neftune_hook_handle'):\n"
+            "    self.neftune_hook_handle.remove()\n"
+            "    if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"
+            "\n"
+            "if getattr(self, 'neftune_noise_alpha', None) is not None:\n"
+            "    self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"
+            "    self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"
+            "pass\n"
+            "\n"
+        )
 
         # Also DPO weirdly tokenizes non numeric columns? Delete them!
-        check_text += \
-        "\n"\
-        "if hasattr(self.train_dataset, 'column_names'):\n"\
-        "    column_names = set(self.train_dataset.column_names)\n"\
-        "    check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
-        "        'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
-        "        'prompt_input_ids', 'prompt_attention_mask']\n"\
-        "    if all(x in column_names for x in check):\n"\
-        "        self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
-        "    del check, column_names\n"\
-        "\n"
+        check_text += (
+            "\n"
+            "if hasattr(self.train_dataset, 'column_names'):\n"
+            "    column_names = set(self.train_dataset.column_names)\n"
+            "    check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"
+            "        'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"
+            "        'prompt_input_ids', 'prompt_attention_mask']\n"
+            "    if all(x in column_names for x in check):\n"
+            "        self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"
+            "    del check, column_names\n"
+            "\n"
+        )
 
         check_text = check_text.split("\n")
-        check_text = "\n".join(" "*where + x for x in check_text)
+        check_text = "\n".join(" " * where + x for x in check_text)
 
         function = function.replace(replacer, check_text + replacer)
         exec(function, globals())
 
-        exec(f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}", globals())
-    pass
-pass
+        exec(
+            f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}",
+            globals(),
+        )
+
 
 # Finally patch TRL tokenizer things -> moved to RL
 # patch_sft_trainer_tokenizer()
diff --git a/unsloth/trainer.py b/unsloth/trainer.py
index 6a42fe4a2..6bdb5604d 100644
--- a/unsloth/trainer.py
+++ b/unsloth/trainer.py
@@ -40,38 +40,39 @@
 
 # Unsloth gradient accumulation fix:
 from transformers import __version__ as transformers_version
+
 if Version(transformers_version) > Version("4.45.2"):
+
     def unsloth_train(trainer, *args, **kwargs):
         return trainer.train(*args, **kwargs)
-    pass
+
 else:
+
     def unsloth_train(trainer, *args, **kwargs):
         if len(args) != 0 or len(kwargs) != 0:
             raise RuntimeError(
-                "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"\
-                "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
-                '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
+                "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"
+                "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"
+                "`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`"
             )
         print(
-            "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\
-            "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
-            '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
+            "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"
+            "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"
+            "`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`"
         )
         return _unsloth_train(trainer)
-    pass
-pass
+
 
 try:
     from trl import SFTConfig as TrainingArguments
 except:
     from transformers import TrainingArguments
-pass
+
 
 class UnslothTrainingArguments(TrainingArguments):
     def __init__(self, embedding_learning_rate: float = None, *args, **kwargs):
         embedding_learning_rate = embedding_learning_rate
         super().__init__(*args, **kwargs)
-pass
 
 
 def _create_unsloth_optimizer(
@@ -83,64 +84,64 @@ def _create_unsloth_optimizer(
     lr = optimizer_kwargs["lr"]
     weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
 
-    param_groups = \
-    {
-        "non_embeddings" : {},
-        "embeddings"     : {},
+    param_groups = {
+        "non_embeddings": {},
+        "embeddings": {},
     }
 
     for name, param in model.named_parameters():
-        if not param.requires_grad: continue
+        if not param.requires_grad:
+            continue
         if name.endswith("modules_to_save.default.weight"):
-            partial_name = name[:-len(".modules_to_save.default.weight")]
-            partial_name = partial_name[partial_name.rfind(".")+1:]
-            print(f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}.")
-            param_groups["embeddings"]    [name] = param
+            partial_name = name[: -len(".modules_to_save.default.weight")]
+            partial_name = partial_name[partial_name.rfind(".") + 1 :]
+            print(
+                f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}."
+            )
+            param_groups["embeddings"][name] = param
         else:
             param_groups["non_embeddings"][name] = param
-        pass
-    pass
 
     optimizer_grouped_parameters = [
         {
-            "params"       : list(param_groups["non_embeddings"].values()),
-            "weight_decay" : weight_decay,
-            "lr"           : lr,
+            "params": list(param_groups["non_embeddings"].values()),
+            "weight_decay": weight_decay,
+            "lr": lr,
         },
         {
-            "params"       : list(param_groups["embeddings"].values()),
-            "weight_decay" : weight_decay,
-            "lr"           : embedding_lr,
+            "params": list(param_groups["embeddings"].values()),
+            "weight_decay": weight_decay,
+            "lr": embedding_lr,
         },
     ]
     optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
     return optimizer
-pass
 
 
 class UnslothTrainer(SFTTrainer):
     def create_optimizer(self):
         embedding_learning_rate = getattr(self.args, "embedding_learning_rate", None)
-        if embedding_learning_rate is None: return super().create_optimizer()
+        if embedding_learning_rate is None:
+            return super().create_optimizer()
 
         if self.optimizer is None:
-            optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(self.args)
+            optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(
+                self.args
+            )
             self.optimizer = _create_unsloth_optimizer(
                 self.model,
                 optimizer_cls,
                 optimizer_kwargs,
                 embedding_learning_rate,
             )
-        pass
         return self.optimizer
-    pass
-pass
+
 
 # From `trl>=0.13.0`, they changed how to pass several params to the trainer
 # We need to patch to make the transition smooth
 def _backwards_compatible_trainer(trainer_class, config_class):
     original_init = trainer_class.__init__
-    
+
     @wraps(original_init)
     def new_init(self, *args, **kwargs):
         # All Trainer tokenizer are now called processing_class
@@ -148,21 +149,21 @@ def new_init(self, *args, **kwargs):
 
         if "processing_class" in trainer_params and "tokenizer" in kwargs:
             kwargs["processing_class"] = kwargs.pop("tokenizer")
-        pass
 
         if ("args" in kwargs) and (Version(trl.__version__) >= Version("0.13.0.dev0")):
             training_args = kwargs.pop("args", None)
 
             # Get parameters that Trainer.__init__ actually expects
-            trainer_params.remove('self')
-            trainer_params.remove('args')
+            trainer_params.remove("self")
+            trainer_params.remove("args")
 
             # Get fields that should be passed to Config init
             config_fields = {
-                field.name: field for field in dataclasses.fields(config_class) 
+                field.name: field
+                for field in dataclasses.fields(config_class)
                 if field.init
             }
-            
+
             # Create config dict with valid fields from training_args
             config_dict = {
                 name: getattr(training_args, name)
@@ -172,22 +173,22 @@ def new_init(self, *args, **kwargs):
 
             # Get parameters that exist in Config but not in TrainingArguments
             from transformers import TrainingArguments
-            moved_params = \
-                set(inspect.signature(config_class)     .parameters.keys()) - \
-                set(inspect.signature(TrainingArguments).parameters.keys())
-            
+
+            moved_params = set(inspect.signature(config_class).parameters.keys()) - set(
+                inspect.signature(TrainingArguments).parameters.keys()
+            )
+
             # Separate kwargs into trainer kwargs and config kwargs
             trainer_kwargs = {}
             additional_config_kwargs = {}
 
             for key, value in kwargs.items():
-                if key in trainer_params: trainer_kwargs[key] = value
+                if key in trainer_params:
+                    trainer_kwargs[key] = value
                 elif key in moved_params or key in config_fields:
                     additional_config_kwargs[key] = value
                 else:
                     additional_config_kwargs[key] = value
-                pass
-            pass
 
             # Update config_dict with additional kwargs
             config_dict.update(additional_config_kwargs)
@@ -205,28 +206,35 @@ def new_init(self, *args, **kwargs):
             # Reconstruct kwargs for Trainer
             kwargs = trainer_kwargs
             kwargs["args"] = config
-        pass
         original_init(self, *args, **kwargs)
-    pass
+
     return new_init
-pass
 
 
 def _patch_trl_trainer():
     import trl
-    if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return
-    if Version(trl.__version__) <= Version("0.11.0"): return
+
+    if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"):
+        return
+    if Version(trl.__version__) <= Version("0.11.0"):
+        return
 
     import trl.trainer
+
     trl_classes = dir(trl.trainer)
-    trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer"))
-    trl_configs  = set(x[:-len("Config")]  for x in trl_classes if x.endswith("Config"))
+    trl_trainers = set(
+        x[: -len("Trainer")] for x in trl_classes if x.endswith("Trainer")
+    )
+    trl_configs = set(x[: -len("Config")] for x in trl_classes if x.endswith("Config"))
     trl_classes = list(trl_trainers & trl_configs)
 
     for x in trl_classes:
-        try:    exec(f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals())
-        except: continue
-    pass
+        try:
+            exec(
+                f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)",
+                globals(),
+            )
+        except:
+            continue
 
     trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True
-pass
diff --git a/unsloth/utils/hf_hub.py b/unsloth/utils/hf_hub.py
index 30255b863..75df00fbf 100644
--- a/unsloth/utils/hf_hub.py
+++ b/unsloth/utils/hf_hub.py
@@ -36,7 +36,7 @@ def get_model_info(
     if _HFAPI is None:
         _HFAPI = HfApi()
     try:
-        model_info: ModelInfo = _HFAPI.model_info(model_id, expand=properties)
+        model_info: ModelInfo = _HFAPI.model_info(model_id, expand = properties)
     except Exception as e:
         print(f"Error getting model info for {model_id}: {e}")
         model_info = None
@@ -68,11 +68,11 @@ def list_models(
         properties = None
 
     models: list[ModelInfo] = _HFAPI.list_models(
-        author=author,
-        search=search,
-        sort=sort,
-        limit=limit,
-        expand=properties,
-        full=full,
+        author = author,
+        search = search,
+        sort = sort,
+        limit = limit,
+        expand = properties,
+        full = full,
     )
     return models