
Commit 7509a1f

fix(kernel): fix speed degradation of TopK kernel
1 parent: 44bb55b


src/lm_saes/kernels.py

Lines changed: 34 additions & 14 deletions
@@ -418,34 +418,49 @@ def get_sparse_representation(x, pad_val=0):
     return sparse_indices, sparse_values
 
 
-class TritonDecoderAutograd(torch.autograd.Function):
+class TritonDecoderAutogradJumpReLU(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, feature_acts, decoder_weight, require_precise_feature_acts_grad: bool = True):
+    def forward(ctx, feature_acts, decoder_weight):
         sparse_indices, sparse_values = get_sparse_representation(feature_acts)
-        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight, torch.tensor(require_precise_feature_acts_grad))
+        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
         return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)
 
     @staticmethod
     def backward(ctx, *grad_outputs, **args):
         assert len(grad_outputs) == 1, "grad_outputs must be a single tensor"
         grad_output = grad_outputs[0]
-        sparse_indices, sparse_values, decoder_weight, require_precise_feature_acts_grad = ctx.saved_tensors
+        sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
 
         assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"
 
         decoder_grad = triton_sparse_transpose_dense_matmul(
             sparse_indices, sparse_values, grad_output, N=decoder_weight.size(1)
         ).T
 
-        if require_precise_feature_acts_grad.item():
-            feature_acts_grad = grad_output @ decoder_weight
-        else:
-            feature_acts_grad_sparse = triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices)  # batch_size, K
-            B, d_sae = sparse_indices.size(0), decoder_weight.size(1)
-            feature_acts_grad = torch.zeros(size=(B, d_sae)).to(sparse_values).scatter_(dim=1, index=sparse_indices, src=feature_acts_grad_sparse)
+        # decoder is contiguous when transposed so this is a matching layout
+        return grad_output @ decoder_weight, decoder_grad
+
+
+class TritonDecoderAutogradTopK(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, sparse_indices, sparse_values, decoder_weight):
+        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
+        return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)
 
+    @staticmethod
+    def backward(ctx, *grad_outputs, **args):
+        assert len(grad_outputs) == 1, "grad_outputs must be a single tensor"
+        grad_output = grad_outputs[0]
+        sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
+
+        assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"
+
+        decoder_grad = triton_sparse_transpose_dense_matmul(
+            sparse_indices, sparse_values, grad_output, N=decoder_weight.size(1)
+        ).T
+
         # decoder is contiguous when transposed so this is a matching layout
-        return feature_acts_grad, decoder_grad, None
+        return None, triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices), decoder_grad
 
 

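Why this split addresses the TopK slowdown: the old backward built the feature-activation gradient by scattering a (B, K) sparse gradient into a freshly zeroed dense (B, d_sae) tensor, whereas the new TritonDecoderAutogradTopK takes the sparse representation as its autograd inputs and returns the (B, K) gradient for sparse_values directly (None for the indices, plus the decoder gradient). The pure-PyTorch sketch below, with toy shapes and names that are illustrative rather than taken from the repository, checks that this gather-style gradient matches the chain rule, so no dense scatter is needed:

# Pure-PyTorch sketch (no Triton); all shapes and variable names are illustrative.
import torch

B, K, d_sae, d_model = 4, 3, 16, 8
decoder_weight = torch.randn(d_model, d_sae)             # assumed (d_model, d_sae) layout
sparse_indices = torch.randint(0, d_sae, (B, K))         # indices of the K active features per row
sparse_values = torch.randn(B, K, requires_grad=True)    # their activation values

# Forward: out[b] = sum_k sparse_values[b, k] * decoder_weight[:, sparse_indices[b, k]]
rows = decoder_weight.T[sparse_indices]                   # (B, K, d_model) gathered decoder rows
out = torch.einsum("bk,bkd->bd", sparse_values, rows)     # (B, d_model)

grad_output = torch.randn_like(out)
out.backward(grad_output)

# Chain rule: d out[b] / d sparse_values[b, k] = decoder_weight[:, sparse_indices[b, k]],
# so the gradient w.r.t. sparse_values is a gather plus dot product of shape (B, K);
# a dense (B, d_sae) tensor never needs to be materialized.
expected = torch.einsum("bd,bkd->bk", grad_output, rows)
print(torch.allclose(sparse_values.grad, expected, atol=1e-5))  # True
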
@@ -462,7 +477,12 @@ def decode_with_triton_spmm_kernel(
     Returns:
         output: (B, d_model) - The decoded output.
     """
-    return TritonDecoderAutograd.apply(feature_acts, decoder_weight.T.contiguous().T, require_precise_feature_acts_grad)
+    if require_precise_feature_acts_grad:
+        output = TritonDecoderAutogradJumpReLU.apply(feature_acts, decoder_weight.T.contiguous().T)
+    else:
+        sparse_indices, sparse_values = get_sparse_representation(feature_acts)
+        output = TritonDecoderAutogradTopK.apply(sparse_indices, sparse_values, decoder_weight.T.contiguous().T)
+    return output
 
 
 if __name__ == "__main__":
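For context, a hedged sketch of how the dispatcher above might be called; the import path, the assumed (d_model, d_sae) decoder layout, and the toy top-k input construction are illustrative assumptions, not verified against the rest of the repository:

# Usage sketch; the Triton kernels require a CUDA device.
import torch
from lm_saes.kernels import decode_with_triton_spmm_kernel  # assumed import path

dense = torch.randn(16, 4096, device="cuda")
topk = torch.topk(dense, k=32, dim=-1)
# Toy TopK-style activations: 32 nonzero features per row out of d_sae = 4096.
feature_acts = torch.zeros_like(dense).scatter_(-1, topk.indices, torch.relu(topk.values))
decoder_weight = torch.randn(256, 4096, device="cuda", requires_grad=True)  # assumed (d_model, d_sae)

# JumpReLU-style training: keeps the exact dense gradient w.r.t. feature_acts,
# dispatching to TritonDecoderAutogradJumpReLU.
out_precise = decode_with_triton_spmm_kernel(feature_acts, decoder_weight, require_precise_feature_acts_grad=True)

# TopK-style training: gradients are only needed for the active features,
# so the sparse representation is built once and TritonDecoderAutogradTopK is used.
out_fast = decode_with_triton_spmm_kernel(feature_acts, decoder_weight, require_precise_feature_acts_grad=False)
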
@@ -593,6 +613,6 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
     print(f"🚀 Speedup: {torch_time / triton_time:.2f}x")
 
     # Run test
-    test_triton_decoder(B=16, d_sae=4096, d_model=256, sparsity=0.9, require_precise_feature_acts_grad=False)
+    test_triton_decoder(B=16, d_sae=4096, d_model=256, sparsity=0.9, require_precise_feature_acts_grad=True)
     # Run benchmark
-    # benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.99, warmup=10, iters=10, require_precise_feature_acts_grad=False)
+    benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.999, warmup=10, iters=10, require_precise_feature_acts_grad=True)
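
For scale, a back-of-envelope on the benchmark setting above (assuming fp32 tensors): at sparsity 0.999 with d_sae = 4096 * 32, each row keeps roughly 131 active features, so part of the cost of the old backward was zero-filling and scattering into an 8192 x 131072 gradient, where the new TopK path only produces an 8192 x 131 gradient.

# Back-of-envelope for the benchmark setting above (assuming fp32, 4 bytes per element).
B, d_sae, sparsity = 8192, 4096 * 32, 0.999
K = round(d_sae * (1 - sparsity))        # ~131 active features per row
dense_grad_gb = B * d_sae * 4 / 1e9      # old scatter target: ~4.3 GB per backward
sparse_grad_mb = B * K * 4 / 1e6         # new (B, K) gradient: ~4.3 MB per backward
print(f"{dense_grad_gb:.1f} GB vs {sparse_grad_mb:.1f} MB")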
