Commit 3b20170

respect kwargs spacing

1 parent d22d68a commit 3b20170

104 files changed (+14750 / -9447 lines)

This is a large commit; only a subset of the 104 changed files is shown below.
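
The convention being enforced project-wide puts spaces around the `=` of keyword arguments, the opposite of the default that `ruff format` (and PEP 8) applies to call sites. Taking one call that appears in the test diffs below, the hook turns

    model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config)

into

    model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)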

.pre-commit-config.yaml

Lines changed: 6 additions & 5 deletions
@@ -6,11 +6,12 @@ repos:
         args:
           - --fix
           - --exit-non-zero-on-fix
-      - id: ruff-format
   - repo: local
     hooks:
-      - id: enforce-kwargs-spacing
-        name: Enforce spaces around keyword equals
-        entry: scripts/enforce_kwargs_spacing.py
-        language: system
+      - id: ruff-format-with-kwargs
+        name: Ruff format with kwarg spacing
+        entry: scripts/run_ruff_format.py
+        language: python
         types: [python]
+        additional_dependencies:
+          - ruff==0.6.9
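
Because the hook is declared with `language: python` and pins `ruff==0.6.9` as an `additional_dependencies` entry, pre-commit builds an isolated environment containing ruff and runs `scripts/run_ruff_format.py` in it against the staged Python files. Outside of a commit, the same hook can be exercised with standard pre-commit usage, e.g. `pre-commit run ruff-format-with-kwargs --all-files` (not part of this diff).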

scripts/enforce_kwargs_spacing.py

Lines changed: 10 additions & 5 deletions
@@ -13,7 +13,7 @@
 
 def enforce_spacing(text: str) -> tuple[str, bool]:
     """Return updated text with keyword '=' padded by spaces, plus change flag."""
-    lines = text.splitlines(keepends = True)
+    lines = text.splitlines(keepends=True)
     if not lines:
         return text, False
 
@@ -68,23 +68,28 @@ def process_file(path: Path) -> bool:
         original = handle.read()
         encoding = handle.encoding
     except (OSError, SyntaxError) as exc:  # SyntaxError from tokenize on invalid python
-        print(f"Failed to read {path}: {exc}", file = sys.stderr)
+        print(f"Failed to read {path}: {exc}", file=sys.stderr)
         return False
 
     updated, changed = enforce_spacing(original)
     if changed:
-        path.write_text(updated, encoding = encoding)
+        path.write_text(updated, encoding=encoding)
     return changed
 
 
 def main(argv: list[str]) -> int:
-    parser = argparse.ArgumentParser(description = __doc__)
-    parser.add_argument("files", nargs = "+", help = "Python files to fix")
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("files", nargs="+", help="Python files to fix")
     args = parser.parse_args(argv)
 
     touched: list[Path] = []
+    self_path = Path(__file__).resolve()
+
     for entry in args.files:
         path = Path(entry)
+        # Skip modifying this script to avoid self-edit loops.
+        if path.resolve() == self_path:
+            continue
         if not path.exists() or path.is_dir():
             continue
         if process_file(path):
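
Only fragments of `enforce_spacing` are visible in this hunk, so the padding logic itself is not shown. As orientation, a minimal sketch of one way to implement what the docstring describes (tokenize the source and pad every bare `=` operator found inside brackets) could look like the following. The function name `pad_keyword_equals` and the whole approach are illustrative assumptions, not the script's actual code:

from __future__ import annotations

import io
import tokenize


def pad_keyword_equals(text: str) -> tuple[str, bool]:
    """Return text with ' = ' around bare '=' tokens found inside brackets (sketch only)."""
    lines = text.splitlines(keepends=True)
    if not lines:
        return text, False

    # Record (row, col) of every single '=' operator that occurs inside a
    # bracket pair; in ordinary code that means keyword arguments in calls
    # (and defaults in signatures). tokenize already separates '=' from
    # '==', '+=', etc., and never reports '=' inside string literals.
    targets: list[tuple[int, int]] = []
    depth = 0
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type != tokenize.OP:
            continue
        if tok.string in "([{":
            depth += 1
        elif tok.string in ")]}":
            depth -= 1
        elif tok.string == "=" and depth > 0:
            targets.append(tok.start)

    changed = False
    # Patch from the end backwards so earlier column numbers stay valid.
    for row, col in sorted(targets, reverse=True):
        line = lines[row - 1]  # tokenize rows are 1-based
        padded = line[:col].rstrip(" \t") + " = " + line[col + 1:].lstrip(" \t")
        if padded != line:
            lines[row - 1] = padded
            changed = True
    return "".join(lines), changed


if __name__ == "__main__":
    demo = "check_responses(responses, answer=ANSWER, prompt=prompt)\n"
    print(pad_keyword_equals(demo)[0], end="")
    # -> check_responses(responses, answer = ANSWER, prompt = prompt)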

scripts/run_ruff_format.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+"""Run `ruff format` followed by kwarg spacing enforcement."""
+
+from __future__ import annotations
+
+import subprocess
+import sys
+from pathlib import Path
+
+HERE = Path(__file__).resolve().parent
+
+
+def main(argv: list[str]) -> int:
+    files = [arg for arg in argv if Path(arg).exists()]
+    if not files:
+        return 0
+
+    ruff_cmd = [sys.executable, "-m", "ruff", "format", *files]
+    ruff_proc = subprocess.run(ruff_cmd)
+    if ruff_proc.returncode != 0:
+        return ruff_proc.returncode
+
+    spacing_script = HERE / "enforce_kwargs_spacing.py"
+    spacing_cmd = [sys.executable, str(spacing_script), *files]
+    spacing_proc = subprocess.run(spacing_cmd)
+    return spacing_proc.returncode
+
+
+if __name__ == "__main__":
+    raise SystemExit(main(sys.argv[1:]))
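
A note on how this wrapper ties into the hook definition above: since the hook runs with `language: python` and pins `ruff==0.6.9` via `additional_dependencies`, pre-commit installs ruff into the hook's own environment, so the `sys.executable -m ruff format` invocation here should resolve to that pinned version without ruff being installed in the project itself. The wrapper then re-runs `enforce_kwargs_spacing.py` over the same files, so the kwarg spacing convention is reapplied on top of ruff's output.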

tests/qlora/test_hf_qlora_train_and_merge.py

Lines changed: 25 additions & 25 deletions
@@ -54,45 +54,45 @@
 seed = 42
 batch_size = 5
 num_generations = 5
-tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer])
+tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])
 temperature = 0.8
 max_new_tokens = 20
 
-peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear")
-model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config)
+peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear")
+model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)
 
 prompt = tokenizer.apply_chat_template(
-    [USER_MESSAGE], tokenize=False, add_generation_prompt=True
+    [USER_MESSAGE], tokenize = False, add_generation_prompt = True
 )
 with header_footer_context("Test Prompt and Answer"):
     print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
 
 dataset: Dataset = create_dataset(
-    tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
+    tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
 )
 with header_footer_context("Dataset"):
     print(f"Dataset: {next(iter(dataset))}")
 
 training_args = SFTConfig(
-    output_dir=output_dir,
-    max_steps=max_steps,
-    per_device_train_batch_size=batch_size,
-    log_level="info",
-    report_to="none",
-    num_train_epochs=1,
-    logging_steps=1,
-    seed=seed,
-    bf16=dtype == torch.bfloat16,
-    fp16=dtype == torch.float16,
-    save_strategy="no",
+    output_dir = output_dir,
+    max_steps = max_steps,
+    per_device_train_batch_size = batch_size,
+    log_level = "info",
+    report_to = "none",
+    num_train_epochs = 1,
+    logging_steps = 1,
+    seed = seed,
+    bf16 = dtype == torch.bfloat16,
+    fp16 = dtype == torch.float16,
+    save_strategy = "no",
 )
 
 with header_footer_context("Train Args"):
     print(training_args)
     print(peft_config)
 
 trainer = setup_trainer(
-    model, tokenizer, dataset, training_args, peft_config=peft_config
+    model, tokenizer, dataset, training_args, peft_config = peft_config
 )
 
 with header_footer_context("Model"):
@@ -108,11 +108,11 @@
 responses = sample_responses(
     model,
     tokenizer,
-    prompt=prompt,
+    prompt = prompt,
     **generation_args,
 )
 with header_footer_context("Responses before training"):
-    check_responses(responses, answer=ANSWER, prompt=prompt)
+    check_responses(responses, answer = ANSWER, prompt = prompt)
 
 with header_footer_context("Peft Weights before training"):
     for name, stats in itertools.islice(describe_peft_weights(model), 2):
@@ -129,11 +129,11 @@
 responses = sample_responses(
     model,
     tokenizer,
-    prompt=prompt,
+    prompt = prompt,
     **generation_args,
 )
 with header_footer_context("Responses after training"):
-    check_responses(responses, answer=ANSWER, prompt=prompt)
+    check_responses(responses, answer = ANSWER, prompt = prompt)
 
 model_copy = deepcopy(model)
 
@@ -142,18 +142,18 @@
 responses = sample_responses(
     merged_model,
     tokenizer,
-    prompt=prompt,
+    prompt = prompt,
     **generation_args,
 )
 with header_footer_context("Responses after custom merging to 16bit"):
-    check_responses(responses, answer=ANSWER, prompt=prompt)
+    check_responses(responses, answer = ANSWER, prompt = prompt)
 
 merged_model_peft = model_copy.merge_and_unload()
 responses = sample_responses(
     merged_model_peft,
     tokenizer,
-    prompt=prompt,
+    prompt = prompt,
     **generation_args,
 )
 with header_footer_context("Responses after peft merge_and_unload"):
-    check_responses(responses, answer=ANSWER, prompt=prompt)
+    check_responses(responses, answer = ANSWER, prompt = prompt)

tests/qlora/test_unsloth_qlora_train_and_merge.py

Lines changed: 45 additions & 45 deletions
@@ -50,13 +50,13 @@ def get_unsloth_model_and_tokenizer(
     dtype: torch.dtype = torch.bfloat16,
 ):
     return FastLanguageModel.from_pretrained(
-        model_name=model_name,
-        max_seq_length=max_seq_length,
-        load_in_4bit=load_in_4bit,
-        fast_inference=fast_inference,
-        max_lora_rank=max_lora_rank,
-        gpu_memory_utilization=gpu_memory_utilization,
-        dtype=dtype,
+        model_name = model_name,
+        max_seq_length = max_seq_length,
+        load_in_4bit = load_in_4bit,
+        fast_inference = fast_inference,
+        max_lora_rank = max_lora_rank,
+        gpu_memory_utilization = gpu_memory_utilization,
+        dtype = dtype,
     )
 
 
@@ -69,11 +69,11 @@ def get_unsloth_peft_model(
 ):
     return FastLanguageModel.get_peft_model(
         model,
-        r=lora_rank,
-        target_modules=target_modules,
-        lora_alpha=lora_rank,
-        use_gradient_checkpointing=use_gradient_checkpointing,
-        random_state=random_state,
+        r = lora_rank,
+        target_modules = target_modules,
+        lora_alpha = lora_rank,
+        use_gradient_checkpointing = use_gradient_checkpointing,
+        random_state = random_state,
     )
 
 
@@ -101,48 +101,48 @@ def get_unsloth_peft_model(
 
 model, tokenizer = get_unsloth_model_and_tokenizer(
     model_name,
-    max_seq_length=512,
-    load_in_4bit=True,
-    fast_inference=False,
-    max_lora_rank=lora_rank,
-    dtype=dtype,
+    max_seq_length = 512,
+    load_in_4bit = True,
+    fast_inference = False,
+    max_lora_rank = lora_rank,
+    dtype = dtype,
 )
 temperature = 0.8
 max_new_tokens = 20
 
 model = get_unsloth_peft_model(
     model,
-    lora_rank=lora_rank,
-    target_modules=target_modules,
-    use_gradient_checkpointing=gradient_checkpointing,
-    random_state=seed,
+    lora_rank = lora_rank,
+    target_modules = target_modules,
+    use_gradient_checkpointing = gradient_checkpointing,
+    random_state = seed,
 )
 
 prompt = tokenizer.apply_chat_template(
-    [USER_MESSAGE], tokenize=False, add_generation_prompt=True
+    [USER_MESSAGE], tokenize = False, add_generation_prompt = True
 )
 
 with header_footer_context("Test Prompt and Answer"):
     print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
 
 dataset: Dataset = create_dataset(
-    tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
+    tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
 )
 with header_footer_context("Dataset"):
     print(f"Dataset: {next(iter(dataset))}")
 
 training_args = SFTConfig(
-    output_dir=output_dir,
-    max_steps=max_steps,
-    per_device_train_batch_size=batch_size,
-    log_level="info",
-    report_to="none",
-    num_train_epochs=1,
-    logging_steps=1,
-    seed=seed,
-    bf16=dtype == torch.bfloat16,
-    fp16=dtype == torch.float16,
-    save_strategy="no",
+    output_dir = output_dir,
+    max_steps = max_steps,
+    per_device_train_batch_size = batch_size,
+    log_level = "info",
+    report_to = "none",
+    num_train_epochs = 1,
+    logging_steps = 1,
+    seed = seed,
+    bf16 = dtype == torch.bfloat16,
+    fp16 = dtype == torch.float16,
+    save_strategy = "no",
 )
 
 with header_footer_context("Train Args"):
@@ -163,11 +163,11 @@ def get_unsloth_peft_model(
 responses = sample_responses(
     model,
     tokenizer,
-    prompt=prompt,
+    prompt = prompt,
     **generation_args,
 )
 with header_footer_context("Responses before training"):
-    check_responses(responses, answer=ANSWER, prompt=prompt)
+    check_responses(responses, answer = ANSWER, prompt = prompt)
 with header_footer_context("Peft Weights before training"):
     for name, stats in itertools.islice(describe_peft_weights(model), 2):
         print(f"{name}:\n{stats}")
@@ -183,29 +183,29 @@ def get_unsloth_peft_model(
 responses = sample_responses(
     model,
     tokenizer,
-    prompt=prompt,
+    prompt = prompt,
     **generation_args,
 )
 with header_footer_context("Responses after training"):
-    check_responses(responses, answer=ANSWER, prompt=prompt)
+    check_responses(responses, answer = ANSWER, prompt = prompt)
 
 model.save_pretrained_merged(
     unsloth_merged_path,
     tokenizer,
-    save_method="merged_16bit",
+    save_method = "merged_16bit",
 )
 merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer(
     unsloth_merged_path,
-    max_seq_length=512,
-    load_in_4bit=False,
-    fast_inference=False,
-    dtype=dtype,
+    max_seq_length = 512,
+    load_in_4bit = False,
+    fast_inference = False,
+    dtype = dtype,
 )
 responses = sample_responses(
     merged_model_unsloth,
     tokenizer,
-    prompt=prompt,
+    prompt = prompt,
     **generation_args,
 )
 with header_footer_context("Responses after unsloth merge to 16bit"):
-    check_responses(responses, answer=ANSWER, prompt=prompt)
+    check_responses(responses, answer = ANSWER, prompt = prompt)
