File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -340,12 +340,20 @@ def cuda_kernels_forward(
340340 self .conv1d (hidden_states_B_C .transpose (1 , 2 ))[..., :seq_len ].transpose (1 , 2 )
341341 )
342342 else :
343- hidden_states_B_C = self .causal_conv1d_fn (
343+ _conv1d_output = self .causal_conv1d_fn (
344344 x = hidden_states_B_C .transpose (1 , 2 ).contiguous (),
345345 weight = self .conv1d .weight .squeeze (1 ),
346346 bias = self .conv1d .bias ,
347347 activation = self .activation ,
348- ).transpose (1 , 2 )
348+ )
349+ if self .backend == 'cuda' :
350+ hidden_states_B_C = _conv1d_output
351+ hidden_states_B_C = hidden_states_B_C .transpose (1 , 2 )
352+ elif self .backend == 'triton' :
353+ hidden_states_B_C , _ = _conv1d_output
354+ hidden_states_B_C = hidden_states_B_C .transpose (1 , 2 ).contiguous ()
355+ else :
356+ raise ValueError (f"Unsupported backend: { self .backend } " )
349357
350358 hidden_states_B_C = apply_mask_to_padding_states (hidden_states_B_C , attention_mask )
351359 hidden_states , B , C = torch .split (
You can’t perform that action at this time.
0 commit comments