"""
Test for batch size invariance in FlexAttention.

This module tests whether FlexAttention implementations produce identical results
when processing entries individually vs. in batch. For any given (b, h) position,
the attention output should be the same whether computed in isolation or as part
of a larger batch.
"""

from typing import Optional, Dict, Any, List, Tuple
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    _score_mod_signature,
    _mask_mod_signature,
)

from attn_gym.masks import (
    causal_mask,
    generate_sliding_window,
    generate_prefix_lm_mask,
)
from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap

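
# Illustration (not part of the test): the property exercised below, sketched with
# plain scaled_dot_product_attention. Batch invariance means a slice of the batched
# output matches the same computation done on that slice alone, up to tolerance:
#
#   out_full = F.scaled_dot_product_attention(q, k, v)                # [B, H, S, D]
#   out_one = F.scaled_dot_product_attention(q[0:1], k[0:1], v[0:1])  # [1, H, S, D]
#   torch.testing.assert_close(out_full[0:1], out_one)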
def test_batch_invariance(
    score_mod: Optional[_score_mod_signature] = None,
    mask_mod: Optional[_mask_mod_signature] = None,
    B: int = 4,
    H: int = 8,
    S: int = 128,
    D: int = 64,
    tolerance: float = 1e-5,
    device: str = "cuda",
    data_type: torch.dtype = torch.float16,
    seed: int = 42,
) -> Dict[str, Any]:
    """
    Test batch invariance for FlexAttention with the given configuration.

    Args:
        score_mod: Optional score modification function
        mask_mod: Optional mask modification function
        B: Batch size for testing
        H: Number of attention heads
        S: Sequence length
        D: Head dimension
        tolerance: Numerical tolerance for comparison
        device: Device to run on
        data_type: Data type for tensors
        seed: Random seed for reproducibility

    Returns:
        Dictionary with test results including pass/fail status and metrics
    """
    torch.manual_seed(seed)

    # Generate random input tensors
    qkv_batched = [
        torch.randn(B, H, S, D, device=device, dtype=data_type)
        for _ in range(3)
    ]
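    # Note: flex_attention expects query/key/value laid out as [B, H, S, D]
    # (batch, heads, sequence length, head dimension).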

    # Create block mask if mask_mod is provided
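    # (create_block_mask precomputes which tiles of the attention matrix are
    # entirely masked out, letting flex_attention skip them.)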
    block_mask = None
    if mask_mod is not None:
        block_mask = create_block_mask(mask_mod, B, H, S, S, device=device)

    # Compute batched attention
    flex_attention_fn = torch.compile(flex_attention, dynamic=False)
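    # dynamic=False compiles for static shapes, so the B=1 calls below will be
    # compiled separately rather than reusing a dynamic-shape kernel.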
    batched_output = flex_attention_fn(
        *qkv_batched,
        score_mod=score_mod,
        block_mask=block_mask,
    )

    # Compute individual attention for each batch element
    individual_outputs = []
    for b in range(B):
        qkv_individual = [tensor[b : b + 1] for tensor in qkv_batched]

        # Create block mask for single batch element if needed
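        # (The mask_mods exercised in this suite do not depend on the batch index,
        # so a mask rebuilt with batch size 1 matches the batched one.)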
        individual_block_mask = None
        if mask_mod is not None:
            individual_block_mask = create_block_mask(mask_mod, 1, H, S, S, device=device)

        individual_output = flex_attention_fn(
            *qkv_individual,
            score_mod=score_mod,
            block_mask=individual_block_mask,
        )
        individual_outputs.append(individual_output)

    # Concatenate individual outputs
    individual_concat = torch.cat(individual_outputs, dim=0)

    # Compare outputs
    max_diff = torch.max(torch.abs(batched_output - individual_concat)).item()
    mean_diff = torch.mean(torch.abs(batched_output - individual_concat)).item()

    # Check if test passes
    test_passed = max_diff <= tolerance

    # Find positions with largest differences for debugging
    diff_tensor = torch.abs(batched_output - individual_concat)
    max_diff_idx = torch.unravel_index(torch.argmax(diff_tensor), diff_tensor.shape)
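    # torch.unravel_index turns the flat argmax into (batch, head, seq_q, dim) coordinates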

    return {
        "passed": test_passed,
        "max_difference": max_diff,
        "mean_difference": mean_diff,
        "tolerance": tolerance,
        "max_diff_position": {
            "batch": max_diff_idx[0].item(),
            "head": max_diff_idx[1].item(),
            "seq_q": max_diff_idx[2].item(),
            "dim": max_diff_idx[3].item(),
        },
        "config": {
            "B": B, "H": H, "S": S, "D": D,
            "has_score_mod": score_mod is not None,
            "has_mask_mod": mask_mod is not None,
        },
    }


def run_test_suite(
    test_configs: Dict[str, Dict[str, Any]],
    B: int = 4,
    H: int = 8,
    S: int = 128,
    D: int = 64,
    device: str = "cuda",
    tolerance: float = 1e-5,
) -> Dict[str, Dict[str, Any]]:
    """
    Run batch invariance tests for multiple configurations.

    Args:
        test_configs: Dictionary of test configurations
        B, H, S, D: Tensor dimensions
        device: Device to run on
        tolerance: Numerical tolerance

    Returns:
        Dictionary with results for each test configuration
    """
    results = {}

    print(f"Running batch invariance test suite with B={B}, H={H}, S={S}, D={D}")
    print(f"Device: {device}, Tolerance: {tolerance}")
    print("=" * 70)

    for test_name, config in test_configs.items():
        print(f"Testing {test_name}...")

        try:
            result = test_batch_invariance(
                score_mod=config.get("score_mod"),
                mask_mod=config.get("mask_mod"),
                B=B, H=H, S=S, D=D,
                tolerance=tolerance,
                device=device,
            )

            status = "PASS" if result["passed"] else "FAIL"
            print(f"  {status}: max_diff={result['max_difference']:.2e}, "
                  f"mean_diff={result['mean_difference']:.2e}")

            if not result["passed"]:
                pos = result["max_diff_position"]
                print(f"    Max diff at batch={pos['batch']}, head={pos['head']}, "
                      f"seq={pos['seq_q']}, dim={pos['dim']}")

            results[test_name] = result

        except Exception as e:
            print(f"  ERROR: {str(e)}")
            results[test_name] = {
                "passed": False,
                "error": str(e),
                "config": config,
            }

    print("=" * 70)

    # Summary
    passed_tests = sum(1 for r in results.values() if r.get("passed", False))
    total_tests = len(results)
    print(f"Summary: {passed_tests}/{total_tests} tests passed")

    return results


# Test configurations
TEST_CONFIGS = {
    "no_modifications": {
        # Pure attention without any modifications
    },
    "causal_mask": {
        "mask_mod": causal_mask,
    },
    "alibi_bias": {
        "score_mod": generate_alibi_bias(8),
    },
    "sliding_window": {
        "mask_mod": generate_sliding_window(window_size=32),
    },
    "prefix_lm": {
        "mask_mod": generate_prefix_lm_mask(prefix_length=64),
    },
    "softcap": {
        "score_mod": generate_tanh_softcap(30, approx=False),
    },
    "softcap_approx": {
        "score_mod": generate_tanh_softcap(30, approx=True),
    },
    "causal_plus_alibi": {
        "mask_mod": causal_mask,
        "score_mod": generate_alibi_bias(8),
    },
    "sliding_window_plus_softcap": {
        "mask_mod": generate_sliding_window(window_size=32),
        "score_mod": generate_tanh_softcap(30, approx=True),
    },
}
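
# A new configuration can be added by pairing any mask_mod/score_mod from attn_gym,
# and the suite can also be run programmatically on a subset. Hypothetical example:
#
#   TEST_CONFIGS["prefix_lm_plus_softcap"] = {
#       "mask_mod": generate_prefix_lm_mask(prefix_length=64),
#       "score_mod": generate_tanh_softcap(30, approx=True),
#   }
#   run_test_suite({"prefix_lm_plus_softcap": TEST_CONFIGS["prefix_lm_plus_softcap"]})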


def main(
    tests: List[str] = ["all"],
    batch_size: int = 4,
    num_heads: int = 8,
    seq_len: int = 128,
    head_dim: int = 64,
    device: str = "cuda",
    tolerance: float = 1e-5,
    list_tests: bool = False,
):
    """
    Main function to run batch invariance tests.

    Args:
        tests: List of test names to run, or ["all"] for all tests
        batch_size: Batch size for testing
        num_heads: Number of attention heads
        seq_len: Sequence length
        head_dim: Head dimension
        device: Device to run tests on
        tolerance: Numerical tolerance for comparison
        list_tests: If True, just list available tests and exit
    """
    if list_tests:
        print("Available tests:")
        for test_name in TEST_CONFIGS.keys():
            config = TEST_CONFIGS[test_name]
            desc_parts = []
            if config.get("mask_mod"):
                desc_parts.append(f"mask: {config['mask_mod'].__name__}")
            if config.get("score_mod"):
                desc_parts.append(f"score: {config['score_mod'].__name__}")
            if not desc_parts:
                desc_parts.append("no modifications")
            print(f"  {test_name}: {', '.join(desc_parts)}")
        return

    # Select tests to run
    if "all" in tests:
        configs_to_run = TEST_CONFIGS
    else:
        configs_to_run = {name: TEST_CONFIGS[name] for name in tests if name in TEST_CONFIGS}

    # Check for unknown test names
    unknown_tests = [name for name in tests if name not in TEST_CONFIGS and name != "all"]
    if unknown_tests:
        print(f"Warning: Unknown test names: {unknown_tests}")
        print(f"Available tests: {list(TEST_CONFIGS.keys())}")

    if not configs_to_run:
        print("No valid tests selected. Use --list_tests to see available options.")
        return

    # Fall back to CPU when CUDA is unavailable
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    # Run the test suite
    results = run_test_suite(
        test_configs=configs_to_run,
        B=batch_size,
        H=num_heads,
        S=seq_len,
        D=head_dim,
        device=device,
        tolerance=tolerance,
    )

    # Check if any tests failed
    failed_tests = [name for name, result in results.items() if not result.get("passed", False)]
    if failed_tests:
        print(f"\nFailed tests: {failed_tests}")
        exit(1)
    else:
        print("\nAll tests passed! ✅")


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
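
# Example invocations (hypothetical file name; flag names follow jsonargparse's
# mapping of the main() signature):
#   python batch_invariance.py --list_tests true
#   python batch_invariance.py --tests '["causal_mask", "sliding_window"]' --seq_len 256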