@@ -2679,13 +2679,18 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
2679
2679
bool use_quantized_src1 = false ;
2680
2680
int64_t src1_padded_num_cols = 0 , src1_padded_row_size = 0 , src1_quantized_size = 0 ;
2681
2681
if (ggml_is_quantized (src0_1->type ) && src0_1->type == src0_2->type && src1->ne [1 ] == 1 && src1->ne [3 ] == 1 ) {
2682
- src1_padded_num_cols = GGML_PAD (src1->ne [0 ], MATRIX_ROW_PADDING);
2683
- src1_padded_row_size = src1_padded_num_cols/ggml_blck_size (GGML_TYPE_Q8_1)*ggml_type_size (GGML_TYPE_Q8_1);
2684
- src1_quantized_size = src1_padded_row_size*src1->ne [2 ] + get_mmq_x_max_host (ggml_cuda_info ().devices [ctx.device ].cc )*sizeof (block_q8_1_mmq);
2685
- src1_quantized.alloc (src1_quantized_size);
2686
- use_quantized_src1 = true ;
2682
+ if (ggml_cuda_should_use_mmq (src0_1->type , ggml_cuda_info ().devices [ctx.device ].cc , src1->ne [2 ])) {
2683
+ src1_padded_num_cols = GGML_PAD (src1->ne [0 ], MATRIX_ROW_PADDING);
2684
+ src1_padded_row_size = src1_padded_num_cols/ggml_blck_size (GGML_TYPE_Q8_1)*ggml_type_size (GGML_TYPE_Q8_1);
2685
+ src1_quantized_size = src1_padded_row_size*src1->ne [2 ] + get_mmq_x_max_host (ggml_cuda_info ().devices [ctx.device ].cc )*sizeof (block_q8_1_mmq);
2686
+ src1_quantized.alloc (src1_quantized_size);
2687
+ use_quantized_src1 = true ;
2688
+ }
2689
+ }
2690
+ ggml_cuda_pool_alloc<char > src1_contiguous (ctx.pool ());
2691
+ if (!use_quantized_src1) {
2692
+ src1_contiguous.alloc (sizeof (float )*ggml_nelements (src1));
2687
2693
}
2688
- ggml_cuda_pool_alloc<char > src1_contiguous (ctx.pool (), sizeof (float )*ggml_nelements (src1));
2689
2694
ggml_cuda_pool_alloc<char > dst_up_contiguous (ctx.pool (), sizeof (float )*ggml_nelements (dst));
2690
2695
ggml_cuda_pool_alloc<char > dst_gate_contiguous (ctx.pool (), sizeof (float )*ggml_nelements (dst));
2691
2696
ggml_cuda_pool_alloc<char > final_dst_contiguous (ctx.pool ());
@@ -2728,6 +2733,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
2728
2733
k_copy_src_to_contiguous<<<grid_dims, block_dims, 0 , stream>>> (
2729
2734
src1_original, src1_contiguous.get (), dev_row_mapping.get () + mapping_offset, ne10, ne11, nb11, nb12);
2730
2735
CUDA_CHECK (cudaGetLastError ());
2736
+ src1_row.data = src1_contiguous.get ();
2731
2737
}
2732
2738
2733
2739
src0_1_row.data = src0_1_original + i02*nb02;
0 commit comments