Commit 810338e
WIP batch invariance
1 parent b9e260f

1 file changed: +314 −0

examples/batch_invariance_test.py

@@ -0,0 +1,314 @@
"""
Test for batch size invariance in FlexAttention.

This module tests whether FlexAttention implementations produce identical results
when processing entries individually vs. in batch. For any given (b, h) position,
the attention output should be the same whether computed in isolation or as part
of a larger batch.
"""

from typing import Optional, Dict, Any, List
import torch
from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    _score_mod_signature,
    _mask_mod_signature,
)

from attn_gym.masks import (
    causal_mask,
    generate_sliding_window,
    generate_prefix_lm_mask,
)
from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap


def test_batch_invariance(
    score_mod: Optional[_score_mod_signature] = None,
    mask_mod: Optional[_mask_mod_signature] = None,
    B: int = 4,
    H: int = 8,
    S: int = 128,
    D: int = 64,
    tolerance: float = 1e-5,
    device: str = "cuda",
    data_type: torch.dtype = torch.float16,
    seed: int = 42,
) -> Dict[str, Any]:
    """
    Test batch invariance for FlexAttention with the given configuration.

    Args:
        score_mod: Optional score modification function
        mask_mod: Optional mask modification function
        B: Batch size for testing
        H: Number of attention heads
        S: Sequence length
        D: Head dimension
        tolerance: Numerical tolerance for comparison
        device: Device to run on
        data_type: Data type for tensors
        seed: Random seed for reproducibility

    Returns:
        Dictionary with test results, including pass/fail status and metrics
    """
    torch.manual_seed(seed)

    # Generate random input tensors (query, key, value)
    qkv_batched = [
        torch.randn(B, H, S, D, device=device, dtype=data_type)
        for _ in range(3)
    ]

    # Create block mask if mask_mod is provided
    block_mask = None
    if mask_mod is not None:
        block_mask = create_block_mask(mask_mod, B, H, S, S, device=device)

    # Compute batched attention
    flex_attention_fn = torch.compile(flex_attention, dynamic=False)
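    # Note: dynamic=False makes torch.compile specialize on static shapes, so the
    # batched (B) and single-element calls below run as separately specialized
    # kernels; the test checks that they still agree numerically.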
    batched_output = flex_attention_fn(
        *qkv_batched,
        score_mod=score_mod,
        block_mask=block_mask,
    )

    # Compute individual attention for each batch element
    individual_outputs = []
    for b in range(B):
        qkv_individual = [tensor[b : b + 1] for tensor in qkv_batched]

        # Create block mask for a single batch element if needed
        individual_block_mask = None
        if mask_mod is not None:
            individual_block_mask = create_block_mask(mask_mod, 1, H, S, S, device=device)

        individual_output = flex_attention_fn(
            *qkv_individual,
            score_mod=score_mod,
            block_mask=individual_block_mask,
        )
        individual_outputs.append(individual_output)

    # Concatenate individual outputs
    individual_concat = torch.cat(individual_outputs, dim=0)

    # Compare outputs
    diff_tensor = torch.abs(batched_output - individual_concat)
    max_diff = diff_tensor.max().item()
    mean_diff = diff_tensor.mean().item()

    # Check if the test passes
    test_passed = max_diff <= tolerance

    # Find the position with the largest difference for debugging
    max_diff_idx = torch.unravel_index(torch.argmax(diff_tensor), diff_tensor.shape)

    return {
        "passed": test_passed,
        "max_difference": max_diff,
        "mean_difference": mean_diff,
        "tolerance": tolerance,
        "max_diff_position": {
            "batch": max_diff_idx[0].item(),
            "head": max_diff_idx[1].item(),
            "seq_q": max_diff_idx[2].item(),
            "dim": max_diff_idx[3].item(),
        },
        "config": {
            "B": B, "H": H, "S": S, "D": D,
            "has_score_mod": score_mod is not None,
            "has_mask_mod": mask_mod is not None,
        },
    }
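

# Example of a direct call (illustration only; assumes attn-gym is installed):
#   result = test_batch_invariance(mask_mod=causal_mask, B=2, S=256)
#   assert result["passed"], result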


def run_test_suite(
    test_configs: Dict[str, Dict[str, Any]],
    B: int = 4,
    H: int = 8,
    S: int = 128,
    D: int = 64,
    device: str = "cuda",
    tolerance: float = 1e-5,
) -> Dict[str, Dict[str, Any]]:
    """
    Run batch invariance tests for multiple configurations.

    Args:
        test_configs: Dictionary of test configurations
        B, H, S, D: Tensor dimensions
        device: Device to run on
        tolerance: Numerical tolerance

    Returns:
        Dictionary with results for each test configuration
    """
    results = {}

    print(f"Running batch invariance test suite with B={B}, H={H}, S={S}, D={D}")
    print(f"Device: {device}, Tolerance: {tolerance}")
    print("=" * 70)

    for test_name, config in test_configs.items():
        print(f"Testing {test_name}...")

        try:
            result = test_batch_invariance(
                score_mod=config.get("score_mod"),
                mask_mod=config.get("mask_mod"),
                B=B, H=H, S=S, D=D,
                tolerance=tolerance,
                device=device,
            )

            status = "PASS" if result["passed"] else "FAIL"
            print(f"  {status}: max_diff={result['max_difference']:.2e}, "
                  f"mean_diff={result['mean_difference']:.2e}")

            if not result["passed"]:
                pos = result["max_diff_position"]
                print(f"  Max diff at batch={pos['batch']}, head={pos['head']}, "
                      f"seq={pos['seq_q']}, dim={pos['dim']}")

            results[test_name] = result

        except Exception as e:
            print(f"  ERROR: {e}")
            results[test_name] = {
                "passed": False,
                "error": str(e),
                "config": config,
            }

    print("=" * 70)

    # Summary
    passed_tests = sum(1 for r in results.values() if r.get("passed", False))
    total_tests = len(results)
    print(f"Summary: {passed_tests}/{total_tests} tests passed")

    return results
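

# Example (illustration only): run a subset of the configurations defined below.
#   run_test_suite({k: TEST_CONFIGS[k] for k in ("causal_mask", "softcap")}, S=64)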


# Test configurations
TEST_CONFIGS = {
    "no_modifications": {
        # Pure attention without any modifications
    },
    "causal_mask": {
        "mask_mod": causal_mask,
    },
    "alibi_bias": {
        "score_mod": generate_alibi_bias(8),
    },
    "sliding_window": {
        "mask_mod": generate_sliding_window(window_size=32),
    },
    "prefix_lm": {
        "mask_mod": generate_prefix_lm_mask(prefix_length=64),
    },
    "softcap": {
        "score_mod": generate_tanh_softcap(30, approx=False),
    },
    "softcap_approx": {
        "score_mod": generate_tanh_softcap(30, approx=True),
    },
    "causal_plus_alibi": {
        "mask_mod": causal_mask,
        "score_mod": generate_alibi_bias(8),
    },
    "sliding_window_plus_softcap": {
        "mask_mod": generate_sliding_window(window_size=32),
        "score_mod": generate_tanh_softcap(30, approx=True),
    },
}
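# Each entry supplies the optional mask_mod / score_mod kwargs that run_test_suite
# forwards to test_batch_invariance; the empty entry exercises plain attention.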


def main(
    tests: List[str] = ["all"],
    batch_size: int = 4,
    num_heads: int = 8,
    seq_len: int = 128,
    head_dim: int = 64,
    device: str = "cuda",
    tolerance: float = 1e-5,
    list_tests: bool = False,
):
    """
    Main function to run batch invariance tests.

    Args:
        tests: List of test names to run, or ["all"] for all tests
        batch_size: Batch size for testing
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Head dimension
        device: Device to run tests on
        tolerance: Numerical tolerance for comparison
        list_tests: If True, just list available tests and exit
    """
    if list_tests:
        print("Available tests:")
        for test_name, config in TEST_CONFIGS.items():
            desc_parts = []
            if config.get("mask_mod"):
                desc_parts.append(f"mask: {config['mask_mod'].__name__}")
            if config.get("score_mod"):
                desc_parts.append(f"score: {config['score_mod'].__name__}")
            if not desc_parts:
                desc_parts.append("no modifications")
            print(f"  {test_name}: {', '.join(desc_parts)}")
        return

    # Select tests to run
    if "all" in tests:
        configs_to_run = TEST_CONFIGS
    else:
        configs_to_run = {name: TEST_CONFIGS[name] for name in tests if name in TEST_CONFIGS}

    # Warn about unknown test names
    unknown_tests = [name for name in tests if name not in TEST_CONFIGS and name != "all"]
    if unknown_tests:
        print(f"Warning: unknown test names: {unknown_tests}")
        print(f"Available tests: {list(TEST_CONFIGS.keys())}")

    if not configs_to_run:
        print("No valid tests selected. Use --list_tests to see available options.")
        return

    # Fall back to CPU if CUDA was requested but is unavailable
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    # Run the test suite
    results = run_test_suite(
        test_configs=configs_to_run,
        B=batch_size,
        H=num_heads,
        S=seq_len,
        D=head_dim,
        device=device,
        tolerance=tolerance,
    )

    # Exit non-zero if any test failed
    failed_tests = [name for name, result in results.items() if not result.get("passed", False)]
    if failed_tests:
        print(f"\nFailed tests: {failed_tests}")
        raise SystemExit(1)
    else:
        print("\nAll tests passed! ✅")


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
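
# Example invocations (illustrative; jsonargparse derives the flags from main()'s
# parameters):
#   python examples/batch_invariance_test.py --list_tests true
#   python examples/batch_invariance_test.py --tests '["causal_mask", "softcap"]' --seq_len 256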
