Skip to content

Commit 808f526

Browse files
committed
Review: further formatting fixes, add assert and use CPU version of fp32->fp16
1 parent 29d77dc commit 808f526

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6546,11 +6546,11 @@ void ggml_compute_forward_im2col_back_f32(
65466546
}
65476547
}
65486548

6549-
static void ggml_call_mul_mat(ggml_type T, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6550-
void * a, void * b, void * c) {
6551-
const ggml_type_traits * traits = ggml_get_type_traits(T);
6549+
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6550+
void * a, void * b, float * c) {
6551+
const ggml_type_traits * traits = ggml_get_type_traits(type);
65526552
struct ggml_tensor src1 = {};
6553-
src1.type = T;
6553+
src1.type = type;
65546554
src1.ne[0] = k;
65556555
src1.ne[1] = m;
65566556
src1.ne[2] = 1;
@@ -6562,7 +6562,7 @@ static void ggml_call_mul_mat(ggml_type T, const ggml_compute_params * params, i
65626562
src1.data = a;
65636563

65646564
struct ggml_tensor src0 = {};
6565-
src0.type = T;
6565+
src0.type = type;
65666566
src0.ne[0] = k;
65676567
src0.ne[1] = n;
65686568
src0.ne[2] = 1;
@@ -6598,6 +6598,7 @@ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params
65986598
ggml_type kernel_type) {
65996599

66006600
GGML_ASSERT(ggml_is_contiguous(kernel));
6601+
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
66016602
GGML_ASSERT(kernel->type == kernel_type);
66026603

66036604
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
@@ -6620,9 +6621,9 @@ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params
66206621
const int64_t dst_w = dst->ne[0];
66216622
const int64_t dst_h = dst->ne[1];
66226623

6623-
float * src_data = (float*) src->data;
6624-
void * knl_data = kernel->data;
6625-
float * dst_data = (float*) dst->data;
6624+
const float * src_data = (float *) src->data;
6625+
void * knl_data = kernel->data;
6626+
float * dst_data = (float *) dst->data;
66266627

66276628
const int64_t knl_n = knl_w * knl_h * c_in;
66286629
const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
@@ -6653,8 +6654,8 @@ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params
66536654
const int64_t src_x = (p / dst_w) % dst_h;
66546655
const int64_t src_y = p % dst_w;
66556656

6656-
float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
6657-
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6657+
const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6658+
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
66586659

66596660
for (int64_t ic = 0; ic < c_in; ++ic) {
66606661
for (int64_t ky = 0; ky < knl_h; ++ky) {
@@ -6668,15 +6669,15 @@ static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params
66686669
if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
66696670
src_val = 0.0f;
66706671
} else {
6671-
float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6672-
src_val = *src_ptr;
6672+
const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6673+
src_val = *src_ptr;
66736674
}
66746675

66756676
char * element_ptr = dst_row + dst_idx * traits->type_size;
66766677
if (kernel_type == GGML_TYPE_F32) {
66776678
*(float *) element_ptr = src_val;
66786679
} else if (kernel_type == GGML_TYPE_F16) {
6679-
*(ggml_fp16_t *) element_ptr = GGML_FP32_TO_FP16(src_val);
6680+
*(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
66806681
}
66816682
}
66826683
}

ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -987,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
987987
"GLU",
988988
};
989989

990-
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
990+
static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
991991

992992
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
993993
"none",
@@ -1087,7 +1087,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10871087
"glu(x)",
10881088
};
10891089

1090-
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
1090+
static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
10911091

10921092
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10931093

0 commit comments

Comments
 (0)