Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: pre-commit

on:
push:
branches: [main]
pull_request:

jobs:
pre-commit:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install pre-commit
run: |
python -m pip install --upgrade pip
pip install pre-commit

- name: Run pre-commit
run: pre-commit run --all-files
17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.0
hooks:
- id: ruff
args:
- --fix
- --exit-non-zero-on-fix
- repo: local
hooks:
- id: ruff-format-with-kwargs
name: Ruff format with kwarg spacing
entry: scripts/run_ruff_format.py
language: python
types: [python]
additional_dependencies:
- ruff==0.6.9
22 changes: 22 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -865,3 +865,25 @@ amd = [
homepage = "http://www.unsloth.ai"
documentation = "https://github.com/unslothai/unsloth"
repository = "https://github.com/unslothai/unsloth"

[tool.ruff]
target-version = "py311"

[tool.ruff.lint]
select = ["E9", "F63", "F7", "F82"]
ignore = [
"E402",
"E722",
"F403",
"F405",
"F811",
"F821",
"F841",
"F401",
"E731",
"E741",
"F601",
"E712",
]

[tool.ruff.format]
179 changes: 179 additions & 0 deletions scripts/enforce_kwargs_spacing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
"""Ensure keyword arguments use spaces around '=', prune redundant pass statements."""

from __future__ import annotations

import ast
import argparse
import io
import sys
import tokenize
from collections import defaultdict
from pathlib import Path


def enforce_spacing(text: str) -> tuple[str, bool]:
"""Return updated text with keyword '=' padded by spaces, plus change flag."""
lines = text.splitlines(keepends=True)
if not lines:
return text, False

offsets: dict[int, int] = defaultdict(int)
changed = False

reader = io.StringIO(text).readline
for token in tokenize.generate_tokens(reader):
if token.type != tokenize.OP or token.string != "=":
continue

line_index = token.start[0] - 1
col = token.start[1] + offsets[line_index]

if line_index < 0 or line_index >= len(lines):
continue

line = lines[line_index]
if col >= len(line) or line[col] != "=":
continue

line_changed = False

# Insert a space before '=' when missing and not preceded by whitespace.
if col > 0 and line[col - 1] not in {" ", "\t"}:
line = f"{line[:col]} {line[col:]}"
offsets[line_index] += 1
col += 1
line_changed = True
changed = True

# Insert a space after '=' when missing and not followed by whitespace or newline.
next_index = col + 1
if next_index < len(line) and line[next_index] not in {" ", "\t", "\n", "\r"}:
line = f"{line[:next_index]} {line[next_index:]}"
offsets[line_index] += 1
line_changed = True
changed = True

if line_changed:
lines[line_index] = line

if not changed:
return text, False

return "".join(lines), True


def remove_redundant_passes(text: str) -> tuple[str, bool]:
"""Drop pass statements that share a block with other executable code."""

try:
tree = ast.parse(text)
except SyntaxError:
return text, False

redundant: list[ast.Pass] = []

def visit(node: ast.AST) -> None:
for attr in ("body", "orelse", "finalbody"):
value = getattr(node, attr, None)
if not isinstance(value, list) or len(value) <= 1:
continue
for stmt in value:
if isinstance(stmt, ast.Pass):
redundant.append(stmt)
for stmt in value:
if isinstance(stmt, ast.AST):
visit(stmt)
handlers = getattr(node, "handlers", None)
if handlers:
for handler in handlers:
visit(handler)

visit(tree)

if not redundant:
return text, False

lines = text.splitlines(keepends=True)
changed = False

for node in sorted(
redundant, key=lambda item: (item.lineno, item.col_offset), reverse=True
):
start = node.lineno - 1
end = (node.end_lineno or node.lineno) - 1
if start >= len(lines):
continue
changed = True
if start == end:
line = lines[start]
col_start = node.col_offset
col_end = node.end_col_offset or (col_start + 4)
segment = line[:col_start] + line[col_end:]
lines[start] = segment if segment.strip() else ""
continue

# Defensive fall-back for unexpected multi-line 'pass'.
prefix = lines[start][: node.col_offset]
lines[start] = prefix if prefix.strip() else ""
for idx in range(start + 1, end):
lines[idx] = ""
suffix = lines[end][(node.end_col_offset or 0) :]
lines[end] = suffix

# Normalise to ensure lines end with newlines except at EOF.
result_lines: list[str] = []
for index, line in enumerate(lines):
if not line:
continue
if index < len(lines) - 1 and not line.endswith("\n"):
result_lines.append(f"{line}\n")
else:
result_lines.append(line)

return "".join(result_lines), changed


def process_file(path: Path) -> bool:
try:
with tokenize.open(path) as handle:
original = handle.read()
encoding = handle.encoding
except (OSError, SyntaxError) as exc: # SyntaxError from tokenize on invalid python
print(f"Failed to read {path}: {exc}", file=sys.stderr)
return False

updated, changed = enforce_spacing(original)
updated, removed = remove_redundant_passes(updated)
if changed or removed:
path.write_text(updated, encoding=encoding)
return True
return False


def main(argv: list[str]) -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("files", nargs="+", help="Python files to fix")
args = parser.parse_args(argv)

touched: list[Path] = []
self_path = Path(__file__).resolve()

for entry in args.files:
path = Path(entry)
# Skip modifying this script to avoid self-edit loops.
if path.resolve() == self_path:
continue
if not path.exists() or path.is_dir():
continue
if process_file(path):
touched.append(path)

if touched:
for path in touched:
print(f"Adjusted kwarg spacing in {path}")
return 0


if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
30 changes: 30 additions & 0 deletions scripts/run_ruff_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
"""Run `ruff format` followed by kwarg spacing enforcement."""

from __future__ import annotations

import subprocess
import sys
from pathlib import Path

HERE = Path(__file__).resolve().parent


def main(argv: list[str]) -> int:
files = [arg for arg in argv if Path(arg).exists()]
if not files:
return 0

ruff_cmd = [sys.executable, "-m", "ruff", "format", *files]
ruff_proc = subprocess.run(ruff_cmd)
if ruff_proc.returncode != 0:
return ruff_proc.returncode

spacing_script = HERE / "enforce_kwargs_spacing.py"
spacing_cmd = [sys.executable, str(spacing_script), *files]
spacing_proc = subprocess.run(spacing_cmd)
return spacing_proc.returncode


if __name__ == "__main__":
raise SystemExit(main(sys.argv[1:]))
50 changes: 25 additions & 25 deletions tests/qlora/test_hf_qlora_train_and_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,45 +54,45 @@
seed = 42
batch_size = 5
num_generations = 5
tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer])
tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])
temperature = 0.8
max_new_tokens = 20

peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear")
model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config)
peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear")
model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)

prompt = tokenizer.apply_chat_template(
[USER_MESSAGE], tokenize=False, add_generation_prompt=True
[USER_MESSAGE], tokenize = False, add_generation_prompt = True
)
with header_footer_context("Test Prompt and Answer"):
print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")

dataset: Dataset = create_dataset(
tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
)
with header_footer_context("Dataset"):
print(f"Dataset: {next(iter(dataset))}")

training_args = SFTConfig(
output_dir=output_dir,
max_steps=max_steps,
per_device_train_batch_size=batch_size,
log_level="info",
report_to="none",
num_train_epochs=1,
logging_steps=1,
seed=seed,
bf16=dtype == torch.bfloat16,
fp16=dtype == torch.float16,
save_strategy="no",
output_dir = output_dir,
max_steps = max_steps,
per_device_train_batch_size = batch_size,
log_level = "info",
report_to = "none",
num_train_epochs = 1,
logging_steps = 1,
seed = seed,
bf16 = dtype == torch.bfloat16,
fp16 = dtype == torch.float16,
save_strategy = "no",
)

with header_footer_context("Train Args"):
print(training_args)
print(peft_config)

trainer = setup_trainer(
model, tokenizer, dataset, training_args, peft_config=peft_config
model, tokenizer, dataset, training_args, peft_config = peft_config
)

with header_footer_context("Model"):
Expand All @@ -108,11 +108,11 @@
responses = sample_responses(
model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses before training"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)

with header_footer_context("Peft Weights before training"):
for name, stats in itertools.islice(describe_peft_weights(model), 2):
Expand All @@ -129,11 +129,11 @@
responses = sample_responses(
model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after training"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)

model_copy = deepcopy(model)

Expand All @@ -142,18 +142,18 @@
responses = sample_responses(
merged_model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after custom merging to 16bit"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)

merged_model_peft = model_copy.merge_and_unload()
responses = sample_responses(
merged_model_peft,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after peft merge_and_unload"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)
Loading