@@ -5481,6 +5481,8 @@ static void ggml_compute_forward_add_q_f32(
5481
5481
const int ir0 = dr * ith ;
5482
5482
const int ir1 = MIN (ir0 + dr , nr );
5483
5483
5484
+ float * wdata = (float * ) params -> wdata + ne00 * ith ;
5485
+
5484
5486
for (int ir = ir0 ; ir < ir1 ; ++ ir ) {
5485
5487
// src0 indices
5486
5488
const int i03 = ir /(ne02 * ne01 );
@@ -5503,12 +5505,11 @@ static void ggml_compute_forward_add_q_f32(
5503
5505
assert (ne00 % 32 == 0 );
5504
5506
5505
5507
// unquantize row from src0 to temp buffer
5506
- float tmp [ne00 ];
5507
- dequantize_row_q (src0_row , tmp , ne00 );
5508
+ dequantize_row_q (src0_row , wdata , ne00 );
5508
5509
// add src1
5509
- ggml_vec_acc_f32 (ne00 , tmp , src1_row );
5510
+ ggml_vec_acc_f32 (ne00 , wdata , src1_row );
5510
5511
// quantize row to dst
5511
- quantize_row_q (tmp , dst_row , ne00 );
5512
+ quantize_row_q (wdata , dst_row , ne00 );
5512
5513
}
5513
5514
}
5514
5515
@@ -9715,6 +9716,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9715
9716
case GGML_OP_ADD :
9716
9717
{
9717
9718
node -> n_tasks = n_threads ;
9719
+
9720
+ size_t cur = 0 ;
9721
+
9722
+ if (node -> src0 -> type == GGML_TYPE_Q4_0 || node -> src0 -> type == GGML_TYPE_Q4_1 ) {
9723
+ cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> src0 -> ne [0 ] * n_threads ;
9724
+ }
9725
+
9726
+ work_size = MAX (work_size , cur );
9718
9727
} break ;
9719
9728
case GGML_OP_SUB :
9720
9729
case GGML_OP_MUL :
0 commit comments