@@ -418,34 +418,49 @@ def get_sparse_representation(x, pad_val=0):
     return sparse_indices, sparse_values


-class TritonDecoderAutograd(torch.autograd.Function):
+class TritonDecoderAutogradJumpReLU(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, feature_acts, decoder_weight, require_precise_feature_acts_grad: bool = True):
+    def forward(ctx, feature_acts, decoder_weight):
         sparse_indices, sparse_values = get_sparse_representation(feature_acts)
-        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight, torch.tensor(require_precise_feature_acts_grad))
+        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
         return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)

     @staticmethod
     def backward(ctx, *grad_outputs, **args):
         assert len(grad_outputs) == 1, "grad_outputs must be a single tensor"
         grad_output = grad_outputs[0]
-        sparse_indices, sparse_values, decoder_weight, require_precise_feature_acts_grad = ctx.saved_tensors
+        sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors

         assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"

         decoder_grad = triton_sparse_transpose_dense_matmul(
             sparse_indices, sparse_values, grad_output, N=decoder_weight.size(1)
         ).T

-        if require_precise_feature_acts_grad.item():
-            feature_acts_grad = grad_output @ decoder_weight
-        else:
-            feature_acts_grad_sparse = triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices)  # batch_size, K
-            B, d_sae = sparse_indices.size(0), decoder_weight.size(1)
-            feature_acts_grad = torch.zeros(size=(B, d_sae)).to(sparse_values).scatter_(dim=1, index=sparse_indices, src=feature_acts_grad_sparse)
+        # decoder is contiguous when transposed so this is a matching layout
+        return grad_output @ decoder_weight, decoder_grad
+
+
+class TritonDecoderAutogradTopK(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, sparse_indices, sparse_values, decoder_weight):
+        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
+        return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)

+    @staticmethod
+    def backward(ctx, *grad_outputs, **args):
+        assert len(grad_outputs) == 1, "grad_outputs must be a single tensor"
+        grad_output = grad_outputs[0]
+        sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
+
+        assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"
+
+        decoder_grad = triton_sparse_transpose_dense_matmul(
+            sparse_indices, sparse_values, grad_output, N=decoder_weight.size(1)
+        ).T
+
         # decoder is contiguous when transposed so this is a matching layout
-        return feature_acts_grad, decoder_grad, None
+        return None, triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices), decoder_grad



@@ -462,7 +477,12 @@ def decode_with_triton_spmm_kernel(
     Returns:
         output: (B, d_model) - The decoded output.
     """
-    return TritonDecoderAutograd.apply(feature_acts, decoder_weight.T.contiguous().T, require_precise_feature_acts_grad)
+    if require_precise_feature_acts_grad:
+        output = TritonDecoderAutogradJumpReLU.apply(feature_acts, decoder_weight.T.contiguous().T)
+    else:
+        sparse_indices, sparse_values = get_sparse_representation(feature_acts)
+        output = TritonDecoderAutogradTopK.apply(sparse_indices, sparse_values, decoder_weight.T.contiguous().T)
+    return output


 if __name__ == "__main__":
@@ -593,6 +613,6 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
         print(f"🚀 Speedup: {torch_time / triton_time:.2f}x")

     # Run test
-    test_triton_decoder(B=16, d_sae=4096, d_model=256, sparsity=0.9, require_precise_feature_acts_grad=False)
+    test_triton_decoder(B=16, d_sae=4096, d_model=256, sparsity=0.9, require_precise_feature_acts_grad=True)
     # Run benchmark
-    # benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.99, warmup=10, iters=10, require_precise_feature_acts_grad=False)
+    benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.999, warmup=10, iters=10, require_precise_feature_acts_grad=True)
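
For reference, here is a minimal dense-PyTorch sketch (no Triton) of what the two new autograd paths return for d(loss)/d(feature_acts): the JumpReLU variant computes the full `grad_output @ decoder_weight` gradient, while the TopK variant only produces the gradient at the active positions, which is what `triton_dense_dense_sparseout_matmul` computes. The sizes and random tensors below are illustrative assumptions, not part of this change.

```python
import torch

# Illustrative sizes and tensors (assumptions, not taken from the PR)
B, K, d_sae, d_model = 4, 8, 64, 16
decoder_weight = torch.randn(d_model, d_sae)   # size(1) == d_sae, matching N in decoder_grad above
sparse_indices = torch.stack([torch.randperm(d_sae)[:K] for _ in range(B)])  # K active features per row
sparse_values = torch.randn(B, K)

# Dense stand-in for the sparse decode:
# triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)
feature_acts = torch.zeros(B, d_sae).scatter(1, sparse_indices, sparse_values).requires_grad_()
output = feature_acts @ decoder_weight.T       # (B, d_model)

grad_output = torch.randn(B, d_model)
output.backward(grad_output)

# JumpReLU path: the "precise" dense gradient over all d_sae features
precise_grad = grad_output @ decoder_weight    # (B, d_sae)
assert torch.allclose(feature_acts.grad, precise_grad, atol=1e-5)

# TopK path: gradient only at the K active positions, i.e. what
# triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices) returns
topk_grad = precise_grad.gather(1, sparse_indices)  # (B, K)
assert torch.allclose(feature_acts.grad.gather(1, sparse_indices), topk_grad, atol=1e-5)
```

On the active coordinates the two paths agree; the TopK backward simply hands the (B, K) slice straight back to `sparse_values` and never builds the (B, d_sae) gradient, which is what the old `require_precise_feature_acts_grad=False` branch emulated with its scatter.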