@@ -5974,6 +5974,8 @@ static void ggml_compute_forward_add_q_f32(
5974
5974
const int ir0 = dr * ith ;
5975
5975
const int ir1 = MIN (ir0 + dr , nr );
5976
5976
5977
+ float * wdata = (float * ) params -> wdata + ne00 * ith ;
5978
+
5977
5979
for (int ir = ir0 ; ir < ir1 ; ++ ir ) {
5978
5980
// src0 indices
5979
5981
const int i03 = ir /(ne02 * ne01 );
@@ -5996,12 +5998,11 @@ static void ggml_compute_forward_add_q_f32(
5996
5998
assert (ne00 % 32 == 0 );
5997
5999
5998
6000
// unquantize row from src0 to temp buffer
5999
- float tmp [ne00 ];
6000
- dequantize_row_q (src0_row , tmp , ne00 );
6001
+ dequantize_row_q (src0_row , wdata , ne00 );
6001
6002
// add src1
6002
- ggml_vec_acc_f32 (ne00 , tmp , src1_row );
6003
+ ggml_vec_acc_f32 (ne00 , wdata , src1_row );
6003
6004
// quantize row to dst
6004
- quantize_row_q (tmp , dst_row , ne00 );
6005
+ quantize_row_q (wdata , dst_row , ne00 );
6005
6006
}
6006
6007
}
6007
6008
@@ -10198,6 +10199,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10198
10199
case GGML_OP_ADD :
10199
10200
{
10200
10201
node -> n_tasks = n_threads ;
10202
+
10203
+ size_t cur = 0 ;
10204
+
10205
+ if (node -> src0 -> type == GGML_TYPE_Q4_0 || node -> src0 -> type == GGML_TYPE_Q4_1 ) {
10206
+ cur = GGML_TYPE_SIZE [GGML_TYPE_F32 ] * node -> src0 -> ne [0 ] * n_threads ;
10207
+ }
10208
+
10209
+ work_size = MAX (work_size , cur );
10201
10210
} break ;
10202
10211
case GGML_OP_SUB :
10203
10212
case GGML_OP_MUL :
0 commit comments