@@ -14746,7 +14746,7 @@ static void ggml_compute_forward_pool_1d_sk_p0(
14746
14746
14747
14747
const struct ggml_tensor * src = dst->src[0];
14748
14748
14749
- assert(src->type == GGML_TYPE_F32);
14749
+ assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 );
14750
14750
14751
14751
if (params->ith != 0) {
14752
14752
return;
@@ -14759,21 +14759,20 @@ static void ggml_compute_forward_pool_1d_sk_p0(
14759
14759
const int64_t rs = dst->ne[0];
14760
14760
14761
14761
while (cdata < data_end) {
14762
- const float * const srow = (const float *)cdata;
14763
-
14762
+ const void * srow = (const void *)cdata;
14764
14763
int j = 0;
14765
-
14766
14764
for (int64_t i = 0; i < rs; ++i) {
14767
14765
switch (op) {
14768
14766
case GGML_OP_POOL_AVG: drow[i] = 0; break;
14769
14767
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
14770
14768
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
14771
14769
}
14772
14770
for (int ki = 0; ki < k; ++ki) {
14771
+ const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
14773
14772
switch (op) {
14774
- case GGML_OP_POOL_AVG: drow[i] += srow[j] ; break;
14775
- case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j] ; break;
14776
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
14773
+ case GGML_OP_POOL_AVG: drow[i] += srow_j ; break;
14774
+ case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j ; break;
14775
+ case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
14777
14776
}
14778
14777
++j;
14779
14778
}
@@ -14814,7 +14813,7 @@ static void ggml_compute_forward_pool_2d(
14814
14813
14815
14814
const struct ggml_tensor * src = dst->src[0];
14816
14815
14817
- GGML_ASSERT (src->type == GGML_TYPE_F32);
14816
+ assert (src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 );
14818
14817
14819
14818
if (params->ith != 0) {
14820
14819
return;
@@ -14857,14 +14856,15 @@ static void ggml_compute_forward_pool_2d(
14857
14856
14858
14857
for (int ky = 0; ky < k1; ++ky) {
14859
14858
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
14860
- const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
14859
+ const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
14861
14860
for (int kx = 0; kx < k0; ++kx) {
14862
14861
int j = ix + kx;
14863
14862
if (j < 0 || j >= src->ne[0]) continue;
14863
+ const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
14864
14864
switch (op) {
14865
- case GGML_OP_POOL_AVG: *out += srow[j] ; break;
14866
- case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j] ; break;
14867
- case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
14865
+ case GGML_OP_POOL_AVG: *out += srow_j ; break;
14866
+ case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j ; break;
14867
+ case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
14868
14868
}
14869
14869
}
14870
14870
}
0 commit comments