Commit 447e7b4
Support F16 operations
1 parent 96ded47 commit 447e7b4

2 files changed: +36 -33 lines

ggml/src/ggml-cpu/ops.cpp

Lines changed: 34 additions & 31 deletions
@@ -6117,29 +6117,29 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
-static void ggml_call_mul_mat(
-        const ggml_compute_params * params,
-        int64_t m, int64_t n, int64_t k,
-        void * a, void * b, void * c) {
-
+static void ggml_call_mul_mat(ggml_type T, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, void * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(T);
     struct ggml_tensor src1 = {};
+    src1.type = T;
     src1.ne[0] = k;
     src1.ne[1] = m;
     src1.ne[2] = 1;
     src1.ne[3] = 1;
-    src1.nb[0] = sizeof(float);
-    src1.nb[1] = k * sizeof(float);
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
     src1.nb[2] = src1.nb[1];
     src1.nb[3] = src1.nb[2];
     src1.data = a;
 
     struct ggml_tensor src0 = {};
+    src0.type = T;
     src0.ne[0] = k;
     src0.ne[1] = n;
     src0.ne[2] = 1;
     src0.ne[3] = 1;
-    src0.nb[0] = sizeof(float);
-    src0.nb[1] = k * sizeof(float);
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
     src0.nb[2] = src0.nb[1];
     src0.nb[3] = src0.nb[2];
     src0.data = b;
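Note the shape of the change: the helper now takes the element type up front and derives every byte stride from the type traits instead of hard-coding sizeof(float), so the same code path can describe F32 or F16 operands. A minimal sketch of that stride arithmetic for a contiguous k x m matrix, using only the public ggml_type_size() helper (the function name below is illustrative, not part of the commit):

#include "ggml.h"

// Illustrative: fill the byte strides (nb) of a contiguous k x m 2-D
// tensor from its element type, mirroring the patched ggml_call_mul_mat.
static void set_contiguous_2d(struct ggml_tensor * t, enum ggml_type type, int64_t k, int64_t m) {
    const size_t ts = ggml_type_size(type); // 4 for F32, 2 for F16
    t->type  = type;
    t->ne[0] = k; t->ne[1] = m; t->ne[2] = 1; t->ne[3] = 1;
    t->nb[0] = ts;            // stride between elements in a row
    t->nb[1] = k * ts;        // stride between rows
    t->nb[2] = t->nb[1];      // higher dimensions are flat
    t->nb[3] = t->nb[2];
}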
@@ -6160,17 +6160,18 @@ static void ggml_call_mul_mat(
     ggml_compute_forward_mul_mat(params, &dst);
 }
 
-
 // ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_f32(
-        const ggml_compute_params * params,
-        const ggml_tensor * kernel, // [KW, KH, IC, OC] - fp32
-        const ggml_tensor * src,    // [W, H, C, N]
-        ggml_tensor * dst) {        // [OW, OH, OC, N]
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor * kernel, // [KW, KH, IC, OC]
+                                              const ggml_tensor * src,    // [W, H, C, N]
+                                              ggml_tensor * dst,          // [OW, OH, OC, N]
+                                              ggml_type kernel_type) {
 
     GGML_ASSERT(ggml_is_contiguous(kernel));
-    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
 
     const int32_t stride_x = dst->op_params[0];
     const int32_t stride_y = dst->op_params[1];
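The traits lookup happens once per call, and its type_size field drives all the byte arithmetic that follows. A quick sanity check of the two sizes this commit cares about, assuming a ggml build that exposes the traits API used in the diff:

#include "ggml.h"
#include <stdio.h>

int main(void) {
    // type_size is what replaces the old hard-coded sizeof(float)
    printf("F32: %zu bytes\n", ggml_get_type_traits(GGML_TYPE_F32)->type_size); // 4
    printf("F16: %zu bytes\n", ggml_get_type_traits(GGML_TYPE_F16)->type_size); // 2
    return 0;
}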
@@ -6191,20 +6192,20 @@ static void ggml_compute_forward_conv_2d_f32(
     const int64_t dst_h = dst->ne[1];
 
     float * src_data = (float*) src->data;
-    float * knl_data = (float*) kernel->data;
+    void  * knl_data = kernel->data;
     float * dst_data = (float*) dst->data;
 
     const int64_t knl_n = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
 
-    const int64_t space_per_patch = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
     const int64_t batch_size = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
     const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
 
     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 
-    float * tmp = (float *) params->wdata;
+    void * tmp = params->wdata;
 
     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 
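Because space_per_patch now scales with the patch element size, an F16 im2col buffer fits nearly twice as many patches per workspace batch as F32. A worked example with hypothetical sizes (3x3 kernel, 64 input channels, 128 output channels, 1 MiB workspace):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Hypothetical shapes, for intuition only.
    const int64_t knl_n             = 3 * 3 * 64;                      // 576 elements per patch
    const int64_t space_per_patch   = knl_n * 2 + 128 * 4;             // F16 patch + F32 output = 1664 bytes
    const int64_t batch_size        = (1024 * 1024) / space_per_patch; // 630
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size; // 624
    printf("%lld patches per batch\n", (long long) patches_per_batch);
    return 0;
}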
@@ -6224,7 +6225,7 @@ static void ggml_compute_forward_conv_2d_f32(
             const int64_t src_y = p % dst_w;
 
             float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
-            float * dst_row = tmp + (p % patches_per_batch) * knl_n;
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
 
             for (int64_t ic = 0; ic < c_in; ++ic) {
                 for (int64_t ky = 0; ky < knl_h; ++ky) {
@@ -6234,11 +6235,19 @@ static void ggml_compute_forward_conv_2d_f32(
 
                         int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
 
+                        float src_val;
                         if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                            dst_row[dst_idx] = 0.0f;
+                            src_val = 0.0f;
                         } else {
                             float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                            dst_row[dst_idx] = *src_ptr;
+                            src_val = *src_ptr;
+                        }
+
+                        char * element_ptr = dst_row + dst_idx * traits->type_size;
+                        if (kernel_type == GGML_TYPE_F32) {
+                            *(float *) element_ptr = src_val;
+                        } else if (kernel_type == GGML_TYPE_F16) {
+                            *(ggml_fp16_t *) element_ptr = GGML_FP32_TO_FP16(src_val);
                         }
                     }
                 }
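The gather still reads F32 from src; only the store converts, element by element, into the kernel's type. The same dispatch as a standalone sketch, using the public ggml_fp32_to_fp16() conversion (the helper name is illustrative):

#include "ggml.h"
#include <string.h>

// Store one F32 value into a type-erased row at element index idx,
// converting to half precision when the buffer holds F16.
static void store_as(enum ggml_type type, char * row, int64_t idx, float v) {
    char * p = row + idx * ggml_type_size(type);
    if (type == GGML_TYPE_F32) {
        memcpy(p, &v, sizeof(float));
    } else if (type == GGML_TYPE_F16) {
        const ggml_fp16_t h = ggml_fp32_to_fp16(v);
        memcpy(p, &h, sizeof(h));
    }
}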
@@ -6247,11 +6256,10 @@ static void ggml_compute_forward_conv_2d_f32(
 
         ggml_barrier(params->threadpool);
 
-        float * gemm_output = tmp + patches_per_batch * knl_n;
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
 
         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
-        ggml_call_mul_mat(params, patch_n, c_out, knl_n,
-                          tmp, knl_data, gemm_output);
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
 
         ggml_barrier(params->threadpool);
 
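Shapes at the GEMM: tmp holds patch_n rows of knl_n elements (F32 or F16 after this commit), the kernel holds c_out rows of knl_n elements, and gemm_output stays F32 either way. A naive F32-only reference of that contract, to read the call by, not a substitute for ggml_compute_forward_mul_mat:

#include <stdint.h>

// c[i][j] = sum over l of a[i][l] * b[j][l]
// a: m x k patches, b: n x k kernel (one row per output channel), c: m x n
static void naive_mul_mat_f32(int64_t m, int64_t n, int64_t k,
                              const float * a, const float * b, float * c) {
    for (int64_t i = 0; i < m; ++i) {
        for (int64_t j = 0; j < n; ++j) {
            float sum = 0.0f;
            for (int64_t l = 0; l < k; ++l) {
                sum += a[i*k + l] * b[j*k + l];
            }
            c[i*n + j] = sum;
        }
    }
}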
@@ -6269,7 +6277,6 @@ static void ggml_compute_forward_conv_2d_f32(
 
             for (int64_t oc = 0; oc < c_out; ++oc) {
                 const float value = gemm_output[i * c_out + oc];
-                // Write to WHCN layout: dst[w, h, c, n]
                 float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                 *dst_ptr = value;
             }
@@ -6284,11 +6291,7 @@ void ggml_compute_forward_conv_2d(
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    if (src0->type == GGML_TYPE_F16) {
-        GGML_ASSERT(false && "F16 not supported yet");
-    } else {
-        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-    }
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
 // ggml_compute_forward_conv_transpose_2d

ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
@@ -985,7 +985,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1083,7 +1083,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
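Both string tables are pinned to the op count, so bumping the asserts is how a new op forces every table to be revisited. The pattern in miniature, with a hypothetical three-op enum:

#include <assert.h>

enum demo_op { DEMO_OP_NONE, DEMO_OP_ADD, DEMO_OP_CONV_2D, DEMO_OP_COUNT };

static const char * DEMO_OP_NAME[DEMO_OP_COUNT] = {
    "NONE",
    "ADD",
    "CONV_2D", // drop an entry and the slot is silently NULL...
};

// ...so the count is pinned: adding an enum value breaks this assert
// until every name/symbol table guarded by it is updated too.
static_assert(DEMO_OP_COUNT == 3, "DEMO_OP_COUNT != 3");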