Commit 59ce8ac

Merge pull request #94 from OpenMOSS/zf_fix
feat(kernels): support spmm triton kernel for topk saes.
2 parents 834f1e3 + 01844a3 commit 59ce8ac

File tree

5 files changed (+75 additions, -74 deletions)

src/lm_saes/config.py

Lines changed: 3 additions & 0 deletions
@@ -354,6 +354,9 @@ class FeatureAnalyzerConfig(BaseConfig):
 
     sample_weight_exponent: float = 2.0
     """ Exponent for weighting samples by activation value """
+
+    ignore_token_ids: Optional[list[int | None]] = None
+    """ Tokens to ignore in the activations. """
 
     subsamples: dict[str, dict[str, int | float]] = Field(
         default_factory=lambda: {"top_activations": {"proportion": 1.0, "n_samples": 10}}
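
The new field defaults to None, so existing configs are unaffected; when set, the listed token ids are skipped in the activations. A minimal usage sketch (the ids below are placeholders for special tokens such as padding or BOS, and the remaining fields are assumed to have usable defaults):

from lm_saes.config import FeatureAnalyzerConfig

# Hypothetical values: skip pad/BOS-style special tokens during feature analysis.
cfg = FeatureAnalyzerConfig(ignore_token_ids=[0, 1, 2])
print(cfg.ignore_token_ids)  # [0, 1, 2]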

src/lm_saes/crosscoder.py

Lines changed: 5 additions & 1 deletion
@@ -292,7 +292,11 @@ def compute_loss(
     @torch.no_grad()
     def log_statistics(self):
         assert self.dataset_average_activation_norm is not None
-        return {f"info/{k}": v for k, v in self.dataset_average_activation_norm.items()}
+        log_dict = {
+            'metrics/mean_jumprelu_threshold': all_reduce_tensor(self.log_jumprelu_threshold.exp(), aggregate='sum'),
+            'metrics/current_l1_coefficient': self.current_l1_coefficient,
+        }
+        return log_dict
 
     def initialize_with_same_weight_across_layers(self):
         self.encoder.weight.data = get_tensor_from_specific_rank(self.encoder.weight.data.clone(), src=0)

src/lm_saes/kernels.py

Lines changed: 59 additions & 70 deletions
@@ -236,7 +236,7 @@ def triton_dense_dense_sparseout_matmul(
     assert dense2.stride(0) == 1, "dense2 must be contiguous along B"
 
     if K > 512:
-        # print("WARN - using naive matmul for large K")
+        print("WARN - using naive matmul for large K")
         # naive is more efficient for large K
         return (dense1 @ dense2).gather(1, at_indices)
 
@@ -378,7 +378,6 @@ def triton_sparse_dense_matmul_kernel(
     tl.store(out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B)
 
 
-@torch.no_grad()
 def get_sparse_representation(x, pad_val=0):
     """
     Efficiently extracts sparse indices and values from a batched dense tensor x.
@@ -421,33 +420,37 @@ def get_sparse_representation(x, pad_val=0):
 
 class TritonDecoderAutograd(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, sparse_indices, sparse_values, decoder_weight):
-        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
+    def forward(ctx, feature_acts, decoder_weight, require_precise_feature_acts_grad: bool = True):
+        sparse_indices, sparse_values = get_sparse_representation(feature_acts)
+        ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight, torch.tensor(require_precise_feature_acts_grad))
         return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)
 
     @staticmethod
     def backward(ctx, *grad_outputs, **args):
         assert len(grad_outputs) == 1, "grad_outputs must be a single tensor"
         grad_output = grad_outputs[0]
-        sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors
+        sparse_indices, sparse_values, decoder_weight, require_precise_feature_acts_grad = ctx.saved_tensors
 
         assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"
 
         decoder_grad = triton_sparse_transpose_dense_matmul(
-            sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1]
+            sparse_indices, sparse_values, grad_output, N=decoder_weight.size(1)
         ).T
+
+        if require_precise_feature_acts_grad.item():
+            feature_acts_grad = grad_output @ decoder_weight
+        else:
+            feature_acts_grad_sparse = triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices)  # batch_size, K
+            B, d_sae = sparse_indices.size(0), decoder_weight.size(1)
+            feature_acts_grad = torch.zeros(size=(B, d_sae)).to(sparse_values).scatter_(dim=1, index=sparse_indices, src=feature_acts_grad_sparse)
 
-        return (
-            None,
-            triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices),
-            # decoder is contiguous when transposed so this is a matching layout
-            decoder_grad,
-            None,
-        )
+        # decoder is contiguous when transposed so this is a matching layout
+        return feature_acts_grad, decoder_grad, None
+
 
 
 def decode_with_triton_spmm_kernel(
-    feature_acts: Float[torch.Tensor, "batch d_sae"], decoder_weight: Float[torch.Tensor, "d_model d_sae"]
+    feature_acts: Float[torch.Tensor, "batch d_sae"], decoder_weight: Float[torch.Tensor, "d_model d_sae"], require_precise_feature_acts_grad: bool
 ):
     """
     Perform sparse-dense matrix multiplication using Triton.
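
The backward pass now also produces a gradient for feature_acts itself, selected by require_precise_feature_acts_grad: either the full dense product grad_output @ decoder_weight, or a cheaper product evaluated only at the positions that were active in the forward pass and scattered back into zeros. For TopK-style SAEs the inactive positions receive no gradient downstream anyway, so the sparse path loses nothing there. A plain-PyTorch reference sketch of the two paths (shapes are illustrative, not taken from the kernels):

import torch

B, d_sae, d_model = 4, 32, 16
feature_acts = torch.randn(B, d_sae).relu()
feature_acts[torch.rand(B, d_sae) < 0.8] = 0      # simulate TopK-style sparsity
decoder_weight = torch.randn(d_model, d_sae)       # layout the autograd function receives
grad_output = torch.randn(B, d_model)              # upstream gradient of the reconstruction

# require_precise_feature_acts_grad=True: dense gradient for every feature
precise_grad = grad_output @ decoder_weight        # (B, d_sae)

# require_precise_feature_acts_grad=False: gradient kept only where features fired,
# zeros elsewhere (dense reference for the sparseout-matmul + scatter path)
active = feature_acts.ne(0)
sparse_grad = torch.zeros_like(precise_grad)
sparse_grad[active] = precise_grad[active]

# The two paths agree wherever a feature was active in the forward pass.
assert torch.allclose(precise_grad[active], sparse_grad[active])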
@@ -459,13 +462,7 @@ def decode_with_triton_spmm_kernel(
     Returns:
         output: (B, d_model) - The decoded output.
     """
-    # Convert dense feature_acts into sparse representation
-    sparse_indices, sparse_values = get_sparse_representation(feature_acts)
-
-    # Perform sparse-dense multiplication using Triton
-    output = TritonDecoderAutograd.apply(sparse_indices, sparse_values, decoder_weight.T.contiguous().T)
-
-    return output
+    return TritonDecoderAutograd.apply(feature_acts, decoder_weight.T.contiguous().T, require_precise_feature_acts_grad)
 
 
 if __name__ == "__main__":
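
Since the sparse conversion now happens inside TritonDecoderAutograd.forward, callers just pass the dense feature activations plus the flag. A minimal usage sketch (assumes a CUDA device with Triton available; shapes and sparsity are illustrative):

import torch
from lm_saes.kernels import decode_with_triton_spmm_kernel

B, d_sae, d_model = 8, 4096, 256
feature_acts = torch.randn(B, d_sae, device="cuda").relu()
feature_acts[torch.rand_like(feature_acts) < 0.95] = 0   # highly sparse, as after a TopK activation
decoder_weight = torch.randn(d_model, d_sae, device="cuda", requires_grad=True)

# For TopK SAEs the gradient on inactive features is not needed,
# so the cheaper sparse backward path can be requested.
reconstructed = decode_with_triton_spmm_kernel(feature_acts, decoder_weight, require_precise_feature_acts_grad=False)
print(reconstructed.shape)  # torch.Size([8, 256])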
@@ -474,69 +471,62 @@ def decode_with_triton_spmm_kernel(
     import triton
     import triton.language as tl
 
-    def test_triton_decoder_forward():
-        # Set parameters
-        B, d_sae, d_model = 4, 32, 16  # Batch size, input dim, output dim
-
-        # Create a random dense weight matrix (as in nn.Linear), size = (d_model, d_sae)
-        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=torch.float32, device="cuda")
-
-        # Create a random sparse input matrix
-        dense_input = torch.randn((B, d_sae), dtype=torch.float32, device="cuda")
-
-        # Zero out some values to simulate sparsity (~70% sparsity)
-        dense_input[torch.rand_like(dense_input) < 0.7] = 0
-
-        # Run our Triton-based sparse-dense multiply
-        triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
-
-        # Compare against standard dense multiply (nn.Linear equivalent)
-        torch_output = decoder(dense_input)  # Equivalent to nn.Linear
-        assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
-        # Ensure outputs are numerically close
-        assert torch.allclose(triton_output, torch_output, atol=1e-4), "Mismatch between Triton and PyTorch outputs!"
-
-        print("✅ Triton forward pass matches nn.Linear!")
-
-    def test_triton_decoder_backward():
+    def test_triton_decoder(B, d_sae, d_model, sparsity=0.9, dtype=torch.float32, require_precise_feature_acts_grad=True):
         # Set parameters
-        B, d_sae, d_model = 4, 32, 16  # Batch size, input dim, output dim
 
         # Create a random dense weight matrix (as in nn.Linear)
-        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=torch.float32, device="cuda")
+        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=dtype, device="cuda")
 
         # Create a random sparse input matrix
-        dense_input = torch.randn((B, d_sae), dtype=torch.float32, device="cuda")
+        dense_input = torch.randn((B, d_sae), dtype=dtype, device="cuda")
 
-        # Zero out some values to simulate sparsity (~70% sparsity)
-        dense_input[torch.rand_like(dense_input) < 0.7] = 0
+        # Zero out some values to simulate sparsity
+        dense_input[torch.rand_like(dense_input) < sparsity] = 0
 
         # Enable gradient tracking
         decoder.weight.requires_grad_(True)
+        dense_input.requires_grad_(True)
+
+        grad_output = torch.randn((B, d_model), dtype=dtype, device="cuda")
 
         # Run forward pass with Triton
-        triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
+        triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight, require_precise_feature_acts_grad)
         assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
-        # Run forward pass with PyTorch nn.Linear
-        torch_output = decoder(dense_input)
-
-        # Generate random gradient to propagate backward
-        grad_output = torch.randn_like(torch_output)
-
-        # Backpropagate
+
         triton_output.backward(grad_output)
+
+        triton_decoder_weight_grad, triton_dense_input_grad = decoder.weight.grad.clone(), dense_input.grad.clone()  # pyright: ignore
+
+        decoder.weight.grad.zero_()  # pyright: ignore
+        dense_input.grad.zero_()  # pyright: ignore
+
+        torch_output = decoder(dense_input)
         torch_output.backward(grad_output)
+
+        torch_decoder_weight_grad, torch_dense_input_grad = decoder.weight.grad.clone(), dense_input.grad.clone()  # pyright: ignore
 
         # Compare gradients
         assert decoder.weight.grad is not None, "decoder.weight.grad is None"
         assert torch.allclose(
-            decoder.weight.grad, decoder.weight.grad, atol=1e-4
-        ), "Mismatch between Triton and PyTorch gradients!"
+            triton_output, torch_output, atol=1e-5
+        ), "Mismatch between Triton and PyTorch outputs!"
+        assert torch.allclose(
+            triton_decoder_weight_grad, torch_decoder_weight_grad, atol=1e-5
+        ), f"Mismatch between Triton and PyTorch gradients on decoder weights! {triton_decoder_weight_grad=}, {torch_decoder_weight_grad=}"
+
+        if require_precise_feature_acts_grad:
+            assert torch.allclose(
+                triton_dense_input_grad, torch_dense_input_grad, atol=1e-5
+            ), f"Mismatch between Triton and PyTorch gradients on dense input! {triton_dense_input_grad=}, {torch_dense_input_grad=}"
+        else:
+            assert torch.allclose(
+                triton_dense_input_grad[dense_input.ne(0)], torch_dense_input_grad[dense_input.ne(0)], atol=1e-5
            ), f"Mismatch between Triton and PyTorch gradients on dense input! {triton_dense_input_grad=}, {torch_dense_input_grad=}"
 
-        print("✅ Triton backward pass matches nn.Linear!")
+        print("✅ Triton forward and backward pass matches nn.Linear!")
 
     # Ensure we have the Triton-based kernel
-    def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup=5, iters=20):
+    def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup=5, iters=20, dtype=torch.float32, require_precise_feature_acts_grad=True):
         """
         Benchmarks Triton-based sparse-dense multiplication vs PyTorch's nn.Linear.
 
@@ -549,18 +539,18 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
         """
 
         # Create weight matrix similar to nn.Linear
-        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=torch.float32, device="cuda")
+        decoder = nn.Linear(d_sae, d_model, bias=False, dtype=dtype, device="cuda")
 
         # Generate a dense input
-        dense_input = torch.randn((B, d_sae), dtype=torch.float32, device="cuda")
+        dense_input = torch.randn((B, d_sae), dtype=dtype, device="cuda")
 
         # Introduce sparsity
         dense_input[torch.rand_like(dense_input) < sparsity] = 0
 
         # Warmup runs (to eliminate startup overhead)
         for _ in range(warmup):
             torch_output = decoder(dense_input)
-            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
+            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight, require_precise_feature_acts_grad)
             assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
             grad_output = torch.randn_like(triton_output)
             triton_output.backward(grad_output)
@@ -588,7 +578,7 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
 
         start_triton.record()  # type: ignore
         for _ in range(iters):
-            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight)
+            triton_output = decode_with_triton_spmm_kernel(dense_input, decoder.weight, require_precise_feature_acts_grad)
             assert isinstance(triton_output, torch.Tensor), "triton_output is not a torch.Tensor"
             grad_output = torch.randn_like(triton_output)
             triton_output.backward(grad_output)
@@ -603,7 +593,6 @@ def benchmark_triton_vs_torch(B=32, d_sae=512, d_model=256, sparsity=0.7, warmup
         print(f"🚀 Speedup: {torch_time / triton_time:.2f}x")
 
     # Run test
-    test_triton_decoder_forward()
-    test_triton_decoder_backward()
+    test_triton_decoder(B=16, d_sae=4096, d_model=256, sparsity=0.9, require_precise_feature_acts_grad=False)
     # Run benchmark
-    benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.99, warmup=10, iters=100)
+    # benchmark_triton_vs_torch(B=8192, d_sae=4096 * 32, d_model=4096, sparsity=0.99, warmup=10, iters=10, require_precise_feature_acts_grad=False)

src/lm_saes/runner.py

Lines changed: 5 additions & 0 deletions
@@ -155,6 +155,9 @@ class GenerateActivationsSettings(BaseSettings):
 
     mongo: Optional[MongoDBConfig] = None
    """Configuration for the MongoDB database. If `None`, will not use the database."""
+
+    ignore_token_ids: Optional[list[int]] = None
+    """ Tokens to ignore in the activations. """
 
     @model_validator(mode="after")
     def validate_cfg(self) -> "GenerateActivationsSettings":
@@ -207,6 +210,7 @@ def generate_activations(settings: GenerateActivationsSettings) -> None:
         batch_size=settings.batch_size,
         buffer_size=settings.buffer_size,
         buffer_shuffle=settings.buffer_shuffle,
+        ignore_token_ids=settings.ignore_token_ids
     )
 
     # Configure activation writer
@@ -373,6 +377,7 @@ def train_sae(settings: TrainSAESettings) -> None:
     eval_fn = (lambda x: None) if settings.eval else None
 
     trainer = Trainer(settings.trainer)
+    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
     trainer.fit(sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger)
     sae.save_pretrained(
         save_path=settings.trainer.exp_result_path,

src/lm_saes/sae.py

Lines changed: 3 additions & 3 deletions
@@ -367,7 +367,6 @@ def save_pretrained(
         # TODO: save dataset_average_activation_norm
         self.save_checkpoint(save_path)
         if self.device_mesh is None or self.device_mesh.get_rank() == 0:
-            self.cfg.save_hyperparameters(save_path)
             if mongo_client is not None:
                 assert (
                     sae_name is not None and sae_series is not None
@@ -489,8 +488,9 @@ def decode(
     ]:  # may be overridden by subclasses
         max_l0_in_batch = feature_acts.gt(0).to(feature_acts).sum(dim=-1).max()
         sparsity_threshold = self.cfg.d_sae * (1 - self.cfg.sparsity_threshold_for_triton_spmm_kernel)
-        if self.cfg.use_triton_kernel and max_l0_in_batch < sparsity_threshold:
-            reconstructed = decode_with_triton_spmm_kernel(feature_acts, self.decoder.weight)
+        if self.cfg.use_triton_kernel and 0 < max_l0_in_batch < sparsity_threshold:  # triton kernel cannot handle empty feature_acts
+            require_precise_feature_acts_grad = "topk" not in self.cfg.act_fn
+            reconstructed = decode_with_triton_spmm_kernel(feature_acts, self.decoder.weight, require_precise_feature_acts_grad)
         else:
             reconstructed = self.decoder(feature_acts)
         reconstructed = self.hook_reconstructed(reconstructed)
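
The dispatch is purely a sparsity check: the Triton path is taken only when at least one feature is active and the densest sample in the batch is still sparser than the configured threshold; for non-TopK activation functions the precise feature-activation gradient is requested. A worked example of the arithmetic (values are illustrative, not repository defaults):

d_sae = 4096 * 32                                  # decoder width
sparsity_threshold_for_triton_spmm_kernel = 0.99   # illustrative config value
sparsity_threshold = d_sae * (1 - sparsity_threshold_for_triton_spmm_kernel)  # ≈ 1310.7

max_l0_in_batch = 256                              # e.g. a TopK SAE with k = 256
use_triton = 0 < max_l0_in_batch < sparsity_threshold
print(use_triton)                                  # True -> decode_with_triton_spmm_kernel is used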
