Skip to content

Commit 80acaeb

Browse files
authored
[Mamba] Fix errors in Triton backend (#576)
1 parent 6b7df98 commit 80acaeb

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

fla/layers/mamba2.py

Lines changed: 10 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -340,12 +340,20 @@ def cuda_kernels_forward(
340340
self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
341341
)
342342
else:
343-
hidden_states_B_C = self.causal_conv1d_fn(
343+
_conv1d_output = self.causal_conv1d_fn(
344344
x=hidden_states_B_C.transpose(1, 2).contiguous(),
345345
weight=self.conv1d.weight.squeeze(1),
346346
bias=self.conv1d.bias,
347347
activation=self.activation,
348-
).transpose(1, 2)
348+
)
349+
if self.backend == 'cuda':
350+
hidden_states_B_C = _conv1d_output
351+
hidden_states_B_C = hidden_states_B_C.transpose(1, 2)
352+
elif self.backend == 'triton':
353+
hidden_states_B_C, _ = _conv1d_output
354+
hidden_states_B_C = hidden_states_B_C.transpose(1, 2).contiguous()
355+
else:
356+
raise ValueError(f"Unsupported backend: {self.backend}")
349357

350358
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
351359
hidden_states, B, C = torch.split(

0 commit comments

Comments (0)