@@ -152,7 +152,7 @@ def test_partial_rotary_preserves_passthrough(self):
             x[..., rotary_dim:],
             rtol=0,
             atol=0,
-            msg="Pass-through dimensions should be exactly preserved"
+            msg="Pass-through dimensions should be exactly preserved",
         )

     def test_partial_rotary_different_factors(self):
@@ -174,6 +174,135 @@ def test_partial_rotary_different_factors(self):
         # Verify pass-through is preserved
         torch.testing.assert_close(result[..., rotary_dim:], x_pass)

+    def test_float32_computation_with_fp16_input(self):
+        """Test that computation happens in float32 even with fp16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Create fp16 input
+        x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+        cos = torch.randn(seq_len, head_dim // 2)
+        sin = torch.randn(seq_len, head_dim // 2)
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_fp16, cos, sin)
+
+        # Output should be fp16
+        assert result.dtype == torch.float16
+
+        # Compare with fp32 computation for numerical accuracy
+        x_fp32 = x_fp16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The fp16 result should be close to the fp32 result when cast to fp32
+        torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-3, atol=1e-3)
+
+    def test_float32_computation_with_bfloat16_input(self):
+        """Test that computation happens in float32 even with bfloat16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Create bfloat16 input
+        x_bf16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16)
+        cos = torch.randn(seq_len, head_dim // 2)
+        sin = torch.randn(seq_len, head_dim // 2)
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_bf16, cos, sin)
+
+        # Output should be bfloat16
+        assert result.dtype == torch.bfloat16
+
+        # Compare with fp32 computation for numerical accuracy
+        x_fp32 = x_bf16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The bf16 result should be close to the fp32 result when cast to fp32
+        torch.testing.assert_close(result.to(torch.float32), result_fp32, rtol=1e-2, atol=1e-2)
+
+    def test_cos_sin_dtype_independence(self):
+        """Test that cos/sin dtype doesn't affect output dtype"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+
+        # Test with different cos/sin dtypes
+        for cos_sin_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+            sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+
+            result = apply_rotary_emb(x, cos, sin)
+
+            # Output dtype should match input x dtype, not cos/sin dtype
+            assert result.dtype == x.dtype
+
+    def test_partial_rotary_float32_computation_with_fp16(self):
+        """Test that partial rotary also uses float32 computation with fp16 input"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+        rotary_dim = 32  # Only rotate half the dimensions
+
+        # Create fp16 input
+        x_fp16 = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16)
+        cos = torch.randn(seq_len, rotary_dim // 2)
+        sin = torch.randn(seq_len, rotary_dim // 2)
+
+        # Store the pass-through part
+        x_pass_original = x_fp16[..., rotary_dim:].clone()
+
+        # Apply rotary embedding
+        result = apply_rotary_emb(x_fp16, cos, sin)
+
+        # Output should be fp16
+        assert result.dtype == torch.float16
+
+        # Pass-through dimensions should be exactly preserved (no dtype conversion artifacts)
+        torch.testing.assert_close(result[..., rotary_dim:], x_pass_original, rtol=0, atol=0)
+
+        # Compare with fp32 computation
+        x_fp32 = x_fp16.to(torch.float32)
+        result_fp32 = apply_rotary_emb(x_fp32, cos, sin)
+
+        # The rotated part should be close to fp32 computation
+        torch.testing.assert_close(
+            result[..., :rotary_dim].to(torch.float32), result_fp32[..., :rotary_dim], rtol=1e-3, atol=1e-3
+        )
+
+    def test_numerical_stability_with_mixed_dtypes(self):
+        """Test numerical stability when x, cos, sin have different dtypes"""
+        batch_size = 2
+        seq_len = 4
+        num_heads = 8
+        head_dim = 64
+
+        # Test various dtype combinations
+        dtype_combinations = [
+            (torch.float16, torch.float32),
+            (torch.bfloat16, torch.float32),
+            (torch.float32, torch.float16),
+        ]
+
+        for x_dtype, cos_sin_dtype in dtype_combinations:
+            x = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=x_dtype)
+            cos = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+            sin = torch.randn(seq_len, head_dim // 2, dtype=cos_sin_dtype)
+
+            # Should not raise any errors
+            result = apply_rotary_emb(x, cos, sin)
+
+            # Output should match input x dtype
+            assert result.dtype == x_dtype
+            assert result.shape == x.shape
+

 class TestRotaryEmbedding:
     """Tests for RotaryEmbedding class"""
@@ -536,11 +665,13 @@ def test_different_batch_patterns(self):
             dtype=torch.float32,
         )

-        position_ids = torch.tensor([
-            [0, 1, 2, 3],  # Sequential
-            [0, 0, 1, 1],  # Repeated
-            [10, 20, 30, 40],  # Large gaps
-        ])
+        position_ids = torch.tensor(
+            [
+                [0, 1, 2, 3],  # Sequential
+                [0, 0, 1, 1],  # Repeated
+                [10, 20, 30, 40],  # Large gaps
+            ]
+        )

         freqs_cis = position_ids_to_freqs_cis(rope, position_ids, qkv_format="bshd")

@@ -601,10 +732,7 @@ def test_freqs_cis_consistency_across_ranks(self, cp_size, cp_rank):
             if len(indices) > 1:
                 # All tokens at this position should have identical freqs_cis
                 for i in range(1, len(indices)):
-                    torch.testing.assert_close(
-                        freqs_cis_rank[indices[0]],
-                        freqs_cis_rank[indices[i]]
-                    )
+                    torch.testing.assert_close(freqs_cis_rank[indices[0]], freqs_cis_rank[indices[i]])

     def test_freqs_cis_cp_with_variable_sequence_lengths(self):
         """Test freqs_cis with variable-length sequences and CP splitting"""
@@ -697,7 +825,7 @@ def test_full_rope_pipeline(self):
         # Step 2: Extract cos and sin from freqs_cis
         # freqs_cis contains concatenated cos and sin
         cos = freqs_cis[..., :32]  # First half is cos
-        sin = freqs_cis[..., 32:] # Second half is sin
+        sin = freqs_cis[..., 32:]  # Second half is sin

         # Step 3: Apply rotary embeddings
         x = torch.randn(batch_size, seq_len, num_heads, 64)
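
The added tests pin down a dtype contract for apply_rotary_emb rather than a specific implementation: compute the rotation in float32, return the result in the input dtype, and leave pass-through dimensions bit-exact. A minimal sketch of that contract, assuming a half-split (non-interleaved) rotation and bshd-shaped inputs; apply_rotary_emb_sketch, its signature, and the broadcasting layout are illustrative assumptions, not the function under test.

# A minimal sketch (not the implementation under test) of the dtype convention
# the tests above encode. Assumes x has shape (batch, seq, heads, head_dim) and
# cos/sin have shape (seq, rotary_dim // 2) with a half-split rotation layout.
import torch


def apply_rotary_emb_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    rotary_dim = 2 * cos.shape[-1]  # cos/sin each cover half of the rotated dims
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]

    # Compute in float32 regardless of the dtypes of x, cos, or sin
    x1, x2 = x_rot.float().chunk(2, dim=-1)
    cos_f = cos.float()[None, :, None, :]  # broadcast over (batch, heads)
    sin_f = sin.float()[None, :, None, :]
    rotated = torch.cat((x1 * cos_f - x2 * sin_f, x1 * sin_f + x2 * cos_f), dim=-1)

    # Cast the rotated part back to the input dtype; pass-through dims are untouched
    return torch.cat((rotated.to(x.dtype), x_pass), dim=-1)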