Skip to content

Commit 0b81e21

Browse files
authored
[SYCL][E2E][JM] Add GELU kernel to the functional suite (#18127)
1 parent 48518f5 commit 0b81e21

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

sycl/test-e2e/Matrix/Inputs/common.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,9 @@ template <typename T, layout Layout> constexpr int vnni_factor() {
261261
static_assert(sizeof(T) <= 4 && "Unsupported type in vnni_factor().");
262262
return 4 / sizeof(T);
263263
}
264+
265+
inline float gelu(float val) {
266+
return val *
267+
(0.5f + 0.5f * sycl::tanh(val * (0.7978845608028654f +
268+
0.035677408136300125f * val * val)));
269+
}

sycl/test-e2e/Matrix/Inputs/joint_matrix_bf16_fill_k_cache_impl.hpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
329329
#else // MANUAL_UNROLL
330330
for (unsigned int n = 0; n < NCache1 / TN; n++) {
331331
#endif // MANUAL_UNROLL
332+
#ifdef GELU
333+
joint_matrix_apply(sg, tC[m][n], [=](float &x) { x = gelu(x); });
334+
#endif // GELU
332335
#ifdef OOB
333336
ext::intel::experimental::matrix::joint_matrix_store_checked(
334337
sg, tC[m][n], pC, colsB, layout::row_major, rowsA, colsB,
@@ -387,9 +390,13 @@ void test(size_t matrix_size_input) {
387390
matrix_rand<T>(matrix_size, matrix_size, A, T(1));
388391
matrix_rand<T>(matrix_size, matrix_size, B, T(1));
389392

390-
matrix_multiply_ref<T, T, TResult, 1>(A, B, refC, matrix_size, matrix_size,
391-
matrix_size);
392-
393+
matrix_multiply_ref<T, T, TResult, 1>(
394+
A, B, refC, matrix_size, matrix_size, matrix_size
395+
#ifdef GELU
396+
,
397+
false, false, false, [](float &x) { x = gelu(x); }
398+
#endif // GELU
399+
);
393400
#ifdef VNNI
394401
T *vnniB = malloc_shared<T>(matrix_size * matrix_size, q);
395402
matrix_vnni<T>(matrix_size, matrix_size, B, vnniB, vnniFactor);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//==------------- joint_matrix_gelu.cpp - DPC++ joint_matrix---------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// UNSUPPORTED: target-nvidia, target-amd
9+
// UNSUPPORTED-INTENDED: aspect-ext_intel_matrix isn't currently supported for
10+
// other triples
11+
// XFAIL: run-mode
12+
// XFAIL-TRACKER: CMPLRLLVM-66371
13+
// REQUIRES: aspect-ext_intel_matrix
14+
15+
// RUN: %{build} -o %t.out -DGELU %fp-model-precise
16+
// RUN: %{run} %t.out
17+
18+
// -ffp-model=precise is added to not depend on compiler defaults.
19+
20+
#include "common.hpp"
21+
#include "joint_matrix_bf16_fill_k_cache_impl.hpp"

0 commit comments

Comments
 (0)