Commit 966aa76: Support F16 operations
Parent: aed4e1f

1 file changed: ggml/src/ggml-cpu/ops.cpp (+34, -31 lines)
@@ -6546,29 +6546,29 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
-static void ggml_call_mul_mat(
-    const ggml_compute_params * params,
-    int64_t m, int64_t n, int64_t k,
-    void * a, void * b, void * c) {
-
+static void ggml_call_mul_mat(ggml_type T, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, void * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(T);
     struct ggml_tensor src1 = {};
+    src1.type = T;
     src1.ne[0] = k;
     src1.ne[1] = m;
     src1.ne[2] = 1;
     src1.ne[3] = 1;
-    src1.nb[0] = sizeof(float);
-    src1.nb[1] = k * sizeof(float);
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
     src1.nb[2] = src1.nb[1];
     src1.nb[3] = src1.nb[2];
     src1.data = a;
 
     struct ggml_tensor src0 = {};
+    src0.type = T;
     src0.ne[0] = k;
     src0.ne[1] = n;
     src0.ne[2] = 1;
     src0.ne[3] = 1;
-    src0.nb[0] = sizeof(float);
-    src0.nb[1] = k * sizeof(float);
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
     src0.nb[2] = src0.nb[1];
     src0.nb[3] = src0.nb[2];
     src0.data = b;
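
Aside (not part of the commit): the stride change above is the crux of the F16 generalization. ggml tensors carry byte strides in nb[], so swapping the hard-coded sizeof(float) for traits->type_size lets the same descriptor cover F32 and F16 operands. A minimal sketch of the type-traits lookup, assuming only the public ggml.h API:

    #include "ggml.h"
    #include <cstdio>

    int main() {
        // type_size is the byte width of one element of the given type
        const ggml_type_traits * f32 = ggml_get_type_traits(GGML_TYPE_F32);
        const ggml_type_traits * f16 = ggml_get_type_traits(GGML_TYPE_F16);

        // For a contiguous [k x m] matrix: nb[0] = type_size, nb[1] = k * type_size,
        // so an F16 row occupies half the bytes of an F32 row for the same k.
        printf("F32 element: %zu bytes, F16 element: %zu bytes\n",
               f32->type_size, f16->type_size);
        return 0;
    }
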
@@ -6589,17 +6589,18 @@ static void ggml_call_mul_mat(
     ggml_compute_forward_mul_mat(params, &dst);
 }
 
-
 // ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_f32(
-    const ggml_compute_params * params,
-    const ggml_tensor * kernel, // [KW, KH, IC, OC] - fp32
-    const ggml_tensor * src,    // [W, H, C, N]
-    ggml_tensor * dst) {        // [OW, OH, OC, N]
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor * kernel, // [KW, KH, IC, OC]
+                                              const ggml_tensor * src,    // [W, H, C, N]
+                                              ggml_tensor * dst,          // [OW, OH, OC, N]
+                                              ggml_type kernel_type) {
 
     GGML_ASSERT(ggml_is_contiguous(kernel));
-    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
 
     const int32_t stride_x = dst->op_params[0];
     const int32_t stride_y = dst->op_params[1];
@@ -6620,20 +6621,20 @@ static void ggml_compute_forward_conv_2d_f32(
     const int64_t dst_h = dst->ne[1];
 
     float * src_data = (float*) src->data;
-    float * knl_data = (float*) kernel->data;
+    void * knl_data = kernel->data;
     float * dst_data = (float*) dst->data;
 
     const int64_t knl_n = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
 
-    const int64_t space_per_patch = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
     const int64_t batch_size = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
     const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
 
     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 
-    float * tmp = (float *) params->wdata;
+    void * tmp = params->wdata;
 
     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 
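Aside on the sizing arithmetic just above: each patch now needs knl_n elements at the kernel's element width plus c_out floats for the GEMM output. A hypothetical worked example (dimensions invented for illustration, not taken from the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical conv: 3x3 kernel, 64 input channels, 128 output channels
        const int64_t knl_w = 3, knl_h = 3, c_in = 64, c_out = 128;
        const int64_t knl_n = knl_w * knl_h * c_in;       // 576 elements per patch

        // space_per_patch as in the diff, for F32 (4 B) vs F16 (2 B) kernels
        const int64_t f32_patch = knl_n * 4 + c_out * 4;  // 2816 bytes
        const int64_t f16_patch = knl_n * 2 + c_out * 4;  // 1664 bytes

        // F16 patches shrink the im2col workspace by ~40% here, so more
        // patches fit per batch within the same params->wsize
        printf("per-patch: f32=%lld B, f16=%lld B\n",
               (long long) f32_patch, (long long) f16_patch);
        return 0;
    }
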
@@ -6653,7 +6654,7 @@ static void ggml_compute_forward_conv_2d_f32(
         const int64_t src_y = p % dst_w;
 
         float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
-        float * dst_row = tmp + (p % patches_per_batch) * knl_n;
+        char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
 
         for (int64_t ic = 0; ic < c_in; ++ic) {
             for (int64_t ky = 0; ky < knl_h; ++ky) {
@@ -6663,11 +6664,19 @@ static void ggml_compute_forward_conv_2d_f32(
 
                     int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
 
+                    float src_val;
                     if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                        dst_row[dst_idx] = 0.0f;
+                        src_val = 0.0f;
                     } else {
                         float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                        dst_row[dst_idx] = *src_ptr;
+                        src_val = *src_ptr;
+                    }
+
+                    char * element_ptr = dst_row + dst_idx * traits->type_size;
+                    if (kernel_type == GGML_TYPE_F32) {
+                        *(float *) element_ptr = src_val;
+                    } else if (kernel_type == GGML_TYPE_F16) {
+                        *(ggml_fp16_t *) element_ptr = GGML_FP32_TO_FP16(src_val);
                     }
                 }
             }
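
The two-step write above (compute src_val in F32, then narrow once per element) matters because the F16 conversion is lossy. The commit uses the internal GGML_FP32_TO_FP16 macro; the public equivalents in ggml.h behave the same way, as this round-trip sketch (not from the commit) shows:

    #include "ggml.h"
    #include <cstdio>

    int main() {
        const float x = 0.1f;                     // not exactly representable in F16
        ggml_fp16_t h    = ggml_fp32_to_fp16(x);  // round to nearest half precision
        const float back = ggml_fp16_to_fp32(h);  // widen again

        printf("f32 = %.9f, f16 round-trip = %.9f\n", x, back);  // expect a small gap
        return 0;
    }
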
@@ -6676,11 +6685,10 @@ static void ggml_compute_forward_conv_2d_f32(
 
         ggml_barrier(params->threadpool);
 
-        float * gemm_output = tmp + patches_per_batch * knl_n;
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
 
         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
-        ggml_call_mul_mat(params, patch_n, c_out, knl_n,
-                          tmp, knl_data, gemm_output);
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
 
         ggml_barrier(params->threadpool);
 
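Since the patch region at the start of the workspace is now kernel_type-sized, gemm_output has to be derived with byte arithmetic rather than float-pointer arithmetic. A standalone illustration of the same layout (sizes hypothetical):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const int64_t patches_per_batch = 8, knl_n = 576, c_out = 128;
        const size_t  type_size = 2;  // e.g. F16 patch elements; GEMM output stays F32

        // workspace = [patch buffer][gemm output], mirroring params->wdata above
        std::vector<char> wdata(patches_per_batch * (knl_n * type_size + c_out * sizeof(float)));

        void  * tmp         = wdata.data();
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * type_size);

        printf("gemm_output starts %td bytes into the workspace\n",
               (char *) gemm_output - (char *) tmp);
        return 0;
    }
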
@@ -6698,7 +6706,6 @@ static void ggml_compute_forward_conv_2d_f32(
 
             for (int64_t oc = 0; oc < c_out; ++oc) {
                 const float value = gemm_output[i * c_out + oc];
-                // Write to WHCN layout: dst[w, h, c, n]
                 float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                 *dst_ptr = value;
             }
@@ -6713,11 +6720,7 @@ void ggml_compute_forward_conv_2d(
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    if (src0->type == GGML_TYPE_F16) {
-        GGML_ASSERT(false && "F16 not supported yet");
-    } else {
-        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-    }
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
 // ggml_compute_forward_conv_transpose_2d
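
With the dispatch collapsed to a single call, the F16 path is reached simply by giving the op an F16 kernel tensor; src0->type flows through as kernel_type. A hedged sketch of preparing such a kernel with the public row-conversion helper from ggml.h (tensor creation and graph wiring omitted; the helper name is illustrative, not from the commit):

    #include "ggml.h"
    #include <vector>

    // Convert F32 kernel weights to F16 before uploading them into a
    // GGML_TYPE_F16 tensor (hypothetical helper).
    std::vector<ggml_fp16_t> kernel_to_f16(const std::vector<float> & w) {
        std::vector<ggml_fp16_t> out(w.size());
        ggml_fp32_to_fp16_row(w.data(), out.data(), (int64_t) w.size());
        return out;
    }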
