Skip to content

Commit 5c58196

Browse files
authored
tighten constraints again
1 parent 1acd121 commit 5c58196

File tree

2 files changed: 14 additions and 14 deletions.

ggml/src/ggml-cpu/ops.cpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3211,8 +3211,8 @@ static void ggml_compute_forward_reglu_f32(
3211 3211   const int nc = src0->ne[0] / 2;
3212 3212   const int nr = ggml_nrows(src0);
3213 3213
3214      - GGML_ASSERT(dst->ne[0] >= nc);
3215      - GGML_ASSERT(ggml_nrows(dst) >= nr);
     3214 + GGML_ASSERT(dst->ne[0] == nc);
     3215 + GGML_ASSERT(ggml_nrows(dst) == nr);
3216 3216
3217 3217   // rows per thread
3218 3218   const int dr = (nr + nth - 1)/nth;
@@ -3252,8 +3252,8 @@ static void ggml_compute_forward_reglu_f16(
3252 3252   const int nc = src0->ne[0] / 2;
3253 3253   const int nr = ggml_nrows(src0);
3254 3254
3255      - GGML_ASSERT(dst->ne[0] >= nc);
3256      - GGML_ASSERT(ggml_nrows(dst) >= nr);
     3255 + GGML_ASSERT(dst->ne[0] == nc);
     3256 + GGML_ASSERT(ggml_nrows(dst) == nr);
3257 3257
3258 3258   // rows per thread
3259 3259   const int dr = (nr + nth - 1)/nth;
@@ -3318,8 +3318,8 @@ static void ggml_compute_forward_geglu_f32(
3318 3318   const int nc = src0->ne[0] / 2;
3319 3319   const int nr = ggml_nrows(src0);
3320 3320
3321      - GGML_ASSERT(dst->ne[0] >= nc);
3322      - GGML_ASSERT(ggml_nrows(dst) >= nr);
     3321 + GGML_ASSERT(dst->ne[0] == nc);
     3322 + GGML_ASSERT(ggml_nrows(dst) == nr);
3323 3323
3324 3324   // rows per thread
3325 3325   const int dr = (nr + nth - 1)/nth;
@@ -3359,8 +3359,8 @@ static void ggml_compute_forward_geglu_f16(
3359 3359   const int nc = src0->ne[0] / 2;
3360 3360   const int nr = ggml_nrows(src0);
3361 3361
3362      - GGML_ASSERT(dst->ne[0] >= nc);
3363      - GGML_ASSERT(ggml_nrows(dst) >= nr);
     3362 + GGML_ASSERT(dst->ne[0] == nc);
     3363 + GGML_ASSERT(ggml_nrows(dst) == nr);
3364 3364
3365 3365   // rows per thread
3366 3366   const int dr = (nr + nth - 1)/nth;
@@ -3425,8 +3425,8 @@ static void ggml_compute_forward_swiglu_f32(
3425 3425   const int nc = src0->ne[0] / 2;
3426 3426   const int nr = ggml_nrows(src0);
3427 3427
3428      - GGML_ASSERT(dst->ne[0] >= nc);
3429      - GGML_ASSERT(ggml_nrows(dst) >= nr);
     3428 + GGML_ASSERT(dst->ne[0] == nc);
     3429 + GGML_ASSERT(ggml_nrows(dst) == nr);
3430 3430
3431 3431   // rows per thread
3432 3432   const int dr = (nr + nth - 1)/nth;
@@ -3466,8 +3466,8 @@ static void ggml_compute_forward_swiglu_f16(
3466 3466   const int nc = src0->ne[0] / 2;
3467 3467   const int nr = ggml_nrows(src0);
3468 3468
3469      - GGML_ASSERT(dst->ne[0] >= nc);
3470      - GGML_ASSERT(ggml_nrows(dst) >= nr);
     3469 + GGML_ASSERT(dst->ne[0] == nc);
     3470 + GGML_ASSERT(ggml_nrows(dst) == nr);
3471 3471
3472 3472   // rows per thread
3473 3473   const int dr = (nr + nth - 1)/nth;

ggml/src/ggml-cuda/unary.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -230,8 +230,8 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
230 230   GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
231 231   GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
232 232   GGML_ASSERT(src0->type == dst->type);
233     - GGML_ASSERT(dst->ne[0] >= nc);
234     - GGML_ASSERT(ggml_nrows(dst) >= ggml_nrows(src0));
    233 + GGML_ASSERT(dst->ne[0] == nc);
    234 + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
235 235
236 236   if (src0->type == GGML_TYPE_F16) {
237 237   unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);

0 commit comments

Comments
 (0)