@@ -678,25 +678,17 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
-    GGML_ASSERT(V || is_mla);
-
     const ggml_tensor * mask = dst->src[3];
 
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
-    GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
-    GGML_ASSERT( K->nb[0] == ggml_element_size(K));
-    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
-
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -721,10 +713,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = V ? (const char *) V->data : nullptr;
-    size_t nb21 = V ? V->nb[1] : nb11;
-    size_t nb22 = V ? V->nb[2] : nb12;
-    size_t nb23 = V ? V->nb[3] : nb13;
+    const char * V_data = (const char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         K_f16.alloc(ggml_nelements(K));
@@ -740,8 +732,7 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
         to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);