Skip to content

Commit eef117d

Browse files
vanaka11IvanFilipov
authored andcommitted
ggml: add support for float16 input tensors in pooling operations (ggml/895)
* Add support for float16 tensors in 1d pooling operations * Add support for float16 input tensors in 2d pooling operations * code cleanup remove unnecessary casting during srow ptr initialization --------- Co-authored-by: vanaka11 <vanaka1189@gmail.com>
1 parent 27e0c1d commit eef117d

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

ggml/src/ggml.c

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14746,7 +14746,7 @@ static void ggml_compute_forward_pool_1d_sk_p0(
1474614746

1474714747
const struct ggml_tensor * src = dst->src[0];
1474814748

14749-
assert(src->type == GGML_TYPE_F32);
14749+
assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
1475014750

1475114751
if (params->ith != 0) {
1475214752
return;
@@ -14759,21 +14759,20 @@ static void ggml_compute_forward_pool_1d_sk_p0(
1475914759
const int64_t rs = dst->ne[0];
1476014760

1476114761
while (cdata < data_end) {
14762-
const float * const srow = (const float *)cdata;
14763-
14762+
const void * srow = (const void *)cdata;
1476414763
int j = 0;
14765-
1476614764
for (int64_t i = 0; i < rs; ++i) {
1476714765
switch (op) {
1476814766
case GGML_OP_POOL_AVG: drow[i] = 0; break;
1476914767
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
1477014768
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
1477114769
}
1477214770
for (int ki = 0; ki < k; ++ki) {
14771+
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
1477314772
switch (op) {
14774-
case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
14775-
case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
14776-
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
14773+
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
14774+
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
14775+
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
1477714776
}
1477814777
++j;
1477914778
}
@@ -14814,7 +14813,7 @@ static void ggml_compute_forward_pool_2d(
1481414813

1481514814
const struct ggml_tensor * src = dst->src[0];
1481614815

14817-
GGML_ASSERT(src->type == GGML_TYPE_F32);
14816+
assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
1481814817

1481914818
if (params->ith != 0) {
1482014819
return;
@@ -14857,14 +14856,15 @@ static void ggml_compute_forward_pool_2d(
1485714856

1485814857
for (int ky = 0; ky < k1; ++ky) {
1485914858
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
14860-
const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
14859+
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
1486114860
for (int kx = 0; kx < k0; ++kx) {
1486214861
int j = ix + kx;
1486314862
if (j < 0 || j >= src->ne[0]) continue;
14863+
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
1486414864
switch (op) {
14865-
case GGML_OP_POOL_AVG: *out += srow[j]; break;
14866-
case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
14867-
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
14865+
case GGML_OP_POOL_AVG: *out += srow_j; break;
14866+
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
14867+
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
1486814868
}
1486914869
}
1487014870
}

0 commit comments

Comments
 (0)