@@ -236,7 +236,7 @@ def triton_dense_dense_sparseout_matmul(
     assert dense2.stride(0) == 1, "dense2 must be contiguous along B"
 
     if K > 512:
-        # print("WARN - using naive matmul for large K")
+        print("WARN - using naive matmul for large K")
         # naive is more efficient for large K
         return (dense1 @ dense2).gather(1, at_indices)
 
@@ -378,7 +378,6 @@ def triton_sparse_dense_matmul_kernel(
     tl.store(out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B)
 
 
-@torch.no_grad()
 def get_sparse_representation(x, pad_val=0):
     """
     Efficiently extracts sparse indices and values from a batched dense tensor x.
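For orientation, a minimal sketch of the padded `(indices, values)` layout that `get_sparse_representation` is documented to return. This is an illustrative assumption for readers, not the repository's implementation: each row's nonzero positions are collected into a `(B, K)` index tensor padded with `pad_val`, with the matching values alongside.

```python
import torch

def sparse_representation_sketch(x: torch.Tensor, pad_val: int = 0):
    """Illustrative only: pad each row's nonzero indices/values to a common width K."""
    nnz_per_row = x.ne(0).sum(dim=-1)
    K = int(nnz_per_row.max().item())  # widest row determines K
    indices = torch.full((x.shape[0], K), pad_val, dtype=torch.long, device=x.device)
    values = torch.zeros((x.shape[0], K), dtype=x.dtype, device=x.device)
    for row in range(x.shape[0]):  # simple loop; a real implementation would vectorize this
        idx = x[row].nonzero(as_tuple=True)[0]
        indices[row, : idx.numel()] = idx
        values[row, : idx.numel()] = x[row, idx]
    return indices, values
```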
@@ -421,33 +420,37 @@ def get_sparse_representation(x, pad_val=0):
 
 class TritonDecoderAutograd(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, sparse_indices, sparse_values, decoder_weight):
-        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
+    def forward(ctx, feature_acts, decoder_weight, require_precise_feature_acts_grad: bool = True):
+        sparse_indices, sparse_values = get_sparse_representation(feature_acts)
+        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight, torch.tensor(require_precise_feature_acts_grad))
         return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)
 
     @staticmethod
     def backward(ctx, *grad_outputs, **args):
         assert len(grad_outputs) == 1, "grad_outputs must be a single tensor"
         grad_output = grad_outputs[0]
-        sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
+        sparse_indices, sparse_values, decoder_weight, require_precise_feature_acts_grad = ctx.saved_tensors
 
         assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"
 
         decoder_grad = triton_sparse_transpose_dense_matmul(
-            sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1]
+            sparse_indices, sparse_values, grad_output, N=decoder_weight.size(1)
         ).T
+
+        if require_precise_feature_acts_grad.item():
+            feature_acts_grad = grad_output @ decoder_weight
+        else:
+            feature_acts_grad_sparse = triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices)  # batch_size, K
+            B, d_sae = sparse_indices.size(0), decoder_weight.size(1)
+            feature_acts_grad = torch.zeros(size=(B, d_sae)).to(sparse_values).scatter_(dim=1, index=sparse_indices, src=feature_acts_grad_sparse)
 
-        return (
-            None,
-            triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices),
-            # decoder is contiguous when transposed so this is a matching layout
-            decoder_grad,
-            None,
-        )
+        # decoder is contiguous when transposed so this is a matching layout
+        return feature_acts_grad, decoder_grad, None
+
 
 
 def decode_with_triton_spmm_kernel(
-    feature_acts: Float[torch.Tensor, "batch d_sae"], decoder_weight: Float[torch.Tensor, "d_model d_sae"]
+    feature_acts: Float[torch.Tensor, "batch d_sae"], decoder_weight: Float[torch.Tensor, "d_model d_sae"], require_precise_feature_acts_grad: bool
 ):
     """
     Perform sparse-dense matrix multiplication using Triton.
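The new branch in `backward` trades exactness for work: the precise path materializes the full `grad_output @ decoder_weight`, while the cheap path computes gradients only at the K stored indices and scatters them into zeros. A standalone sketch of that relationship in plain PyTorch (toy shapes, with a dense `gather` standing in for `triton_dense_dense_sparseout_matmul`):

```python
import torch

# Toy shapes for illustration; not the repository's test values.
B, K, d_sae, d_model = 2, 3, 8, 4
grad_output = torch.randn(B, d_model)
decoder_weight = torch.randn(d_model, d_sae)
indices = torch.randint(0, d_sae, (B, K))  # padded per-row active feature indices

# Precise path: dense gradient w.r.t. feature_acts.
dense_grad = grad_output @ decoder_weight  # (B, d_sae)

# Cheap path: gradient only at the active indices, scattered back into a zero tensor.
sparse_grad = (grad_output @ decoder_weight).gather(1, indices)  # stand-in for the sparse-out kernel
scattered = torch.zeros(B, d_sae).scatter_(dim=1, index=indices, src=sparse_grad)

# The two agree on the active positions; elsewhere the cheap path leaves zeros.
assert torch.allclose(dense_grad.gather(1, indices), scattered.gather(1, indices))
```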
@@ -459,13 +462,7 @@ def decode_with_triton_spmm_kernel(
     Returns:
         output: (B, d_model) - The decoded output.
     """
-    # Convert dense feature_acts into sparse representation
-    sparse_indices, sparse_values = get_sparse_representation(feature_acts)
-
-    # Perform sparse-dense multiplication using Triton
-    output = TritonDecoderAutograd.apply(sparse_indices, sparse_values, decoder_weight.T.contiguous().T)
-
-    return output
+    return TritonDecoderAutograd.apply(feature_acts, decoder_weight.T.contiguous().T, require_precise_feature_acts_grad)
 
 
 if __name__ == "__main__":
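With the rewritten wrapper, callers hand over the dense activations and pick the gradient mode directly. A minimal usage sketch, assuming `decode_with_triton_spmm_kernel` is in scope and using placeholder shapes that mirror the test below:

```python
import torch

feature_acts = torch.randn(16, 4096, device="cuda")   # (batch, d_sae)
feature_acts[torch.rand_like(feature_acts) < 0.9] = 0  # mostly zero, as SAE activations are
decoder_weight = torch.randn(256, 4096, device="cuda")  # (d_model, d_sae)

# Exact gradient w.r.t. feature_acts (dense grad_output @ decoder_weight in backward).
out_precise = decode_with_triton_spmm_kernel(feature_acts, decoder_weight, True)

# Cheaper backward: gradient only at the nonzero feature positions.
out_sparse_grad = decode_with_triton_spmm_kernel(feature_acts, decoder_weight, False)
```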
@@ -474,69 +471,62 @@ def decode_with_triton_spmm_kernel(
     import triton
     import triton.language as tl
 
-    def test_triton_decoder_forward():
-        # Set parameters
-        B, d_sae, d_model = 4, 32, 16  # Batch size, input dim, output dim
-
-        # Create a random dense weight matrix (as in nn.Linear), size = (d_model, d_sae)
-        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=torch.float32, device="cuda")
-
-        # Create a random sparse input matrix
-        dense_input = torch.randn((B, d_sae), dtype=torch.float32, device="cuda")
-
-        # Zero out some values to simulate sparsity (~70% sparsity)
-        dense_input[torch.rand_like(dense_input) < 0.7] = 0
-
-        # Run our Triton-based sparse-dense multiply
-        triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
-
-        # Compare against standard dense multiply (nn.Linear equivalent)
-        torch_output = decoder(dense_input)  # Equivalent to nn.Linear
-        assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
-        # Ensure outputs are numerically close
-        assert torch.allclose(triton_output, torch_output, atol=1e-4), "Mismatch between Triton and PyTorch outputs!"
-
-        print("✅ Triton forward pass matches nn.Linear!")
-
-    def test_triton_decoder_backward():
+    def test_triton_decoder(B, d_sae, d_model, sparsity=0.9, dtype=torch.float32, require_precise_feature_acts_grad=True):
         # Set parameters
-        B, d_sae, d_model = 4, 32, 16  # Batch size, input dim, output dim
 
         # Create a random dense weight matrix (as in nn.Linear)
-        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=torch.float32, device="cuda")
+        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=dtype, device="cuda")
 
         # Create a random sparse input matrix
-        dense_input = torch.randn((B, d_sae), dtype=torch.float32, device="cuda")
+        dense_input = torch.randn((B, d_sae), dtype=dtype, device="cuda")
 
-        # Zero out some values to simulate sparsity (~70% sparsity)
-        dense_input[torch.rand_like(dense_input) < 0.7] = 0
+        # Zero out some values to simulate sparsity
+        dense_input[torch.rand_like(dense_input) < sparsity] = 0
 
         # Enable gradient tracking
         decoder.weight.requires_grad_(True)
+        dense_input.requires_grad_(True)
+
+        grad_output = torch.randn((B, d_model), dtype=dtype, device="cuda")
 
         # Run forward pass with Triton
-        triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
+        triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight, require_precise_feature_acts_grad)
         assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
-        # Run forward pass with PyTorch nn.Linear
-        torch_output = decoder(dense_input)
-
-        # Generate random gradient to propagate backward
-        grad_output = torch.randn_like(torch_output)
-
-        # Backpropagate
+
         triton_output.backward(grad_output)
+
+        triton_decoder_weight_grad, triton_dense_input_grad = decoder.weight.grad.clone(), dense_input.grad.clone()  # pyright: ignore
+
+        decoder.weight.grad.zero_()  # pyright: ignore
+        dense_input.grad.zero_()  # pyright: ignore
+
+        torch_output = decoder(dense_input)
         torch_output.backward(grad_output)
+
+        torch_decoder_weight_grad, torch_dense_input_grad = decoder.weight.grad.clone(), dense_input.grad.clone()  # pyright: ignore
 
         # Compare gradients
         assert decoder.weight.grad is not None, "decoder.weight.grad is None"
         assert torch.allclose(
-            decoder.weight.grad, decoder.weight.grad, atol=1e-4
-        ), "Mismatch between Triton and PyTorch gradients!"
+            triton_output, torch_output, atol=1e-5
+        ), "Mismatch between Triton and PyTorch outputs!"
+        assert torch.allclose(
+            triton_decoder_weight_grad, torch_decoder_weight_grad, atol=1e-5
+        ), f"Mismatch between Triton and PyTorch gradients on decoder weights! {triton_decoder_weight_grad=}, {torch_decoder_weight_grad=}"
+
+        if require_precise_feature_acts_grad:
+            assert torch.allclose(
+                triton_dense_input_grad, torch_dense_input_grad, atol=1e-5
+            ), f"Mismatch between Triton and PyTorch gradients on dense input! {triton_dense_input_grad=}, {torch_dense_input_grad=}"
+        else:
+            assert torch.allclose(
+                triton_dense_input_grad[dense_input.ne(0)], torch_dense_input_grad[dense_input.ne(0)], atol=1e-5
+            ), f"Mismatch between Triton and PyTorch gradients on dense input! {triton_dense_input_grad=}, {torch_dense_input_grad=}"
 
-        print("✅ Triton backward pass matches nn.Linear!")
+        print("✅ Triton forward and backward pass matches nn.Linear!")
 
     # Ensure we have the Triton-based kernel
-    def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup=5, iters=20):
+    def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup=5, iters=20, dtype=torch.float32, require_precise_feature_acts_grad=True):
         """
         Benchmarks Triton-based sparse-dense multiplication vs PyTorch's nn.Linear.
 
@@ -549,18 +539,18 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
         """
 
         # Create weight matrix similar to nn.Linear
-        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=torch.float32, device="cuda")
+        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=dtype, device="cuda")
 
         # Generate a dense input
-        dense_input = torch.randn((B, d_sae), dtype=torch.float32, device="cuda")
+        dense_input = torch.randn((B, d_sae), dtype=dtype, device="cuda")
 
         # Introduce sparsity
         dense_input[torch.rand_like(dense_input) < sparsity] = 0
 
         # Warmup runs (to eliminate startup overhead)
         for _ in range(warmup):
             torch_output = decoder(dense_input)
-            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
+            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight, require_precise_feature_acts_grad)
             assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
             grad_output = torch.randn_like(triton_output)
             triton_output.backward(grad_output)
@@ -588,7 +578,7 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
 
         start_triton.record()  # type: ignore
         for _ in range(iters):
-            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
+            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight, require_precise_feature_acts_grad)
             assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
             grad_output = torch.randn_like(triton_output)
             triton_output.backward(grad_output)
@@ -603,7 +593,6 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
         print(f"🚀 Speedup: {torch_time / triton_time:.2f}x")
 
     # Run test
-    test_triton_decoder_forward()
-    test_triton_decoder_backward()
+    test_triton_decoder(B=16, d_sae=4096, d_model=256, sparsity=0.9, require_precise_feature_acts_grad=False)
     # Run benchmark
-    benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.99, warmup=10, iters=100)
+    # benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.99, warmup=10, iters=10, require_precise_feature_acts_grad=False)