Skip to content

Commit 97dba84

Browse files
authored
[SYCL Spec][Joint Matrix] Add a new overload for joint_matrix_apply to be able to return result into a different matrix (#13153)
Currently, CUDA code that use this pattern: for (int i = 0; i < c_frag.num_elements; i++) { c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i]; } cannot be migrated to SYCL joint matrix. This added overload addresses this limitation.
1 parent 2559d65 commit 97dba84

File tree

1 file changed

+41
-3
lines changed

1 file changed

+41
-3
lines changed

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,14 @@ of the link:sycl_ext_intel_matrix.asciidoc[sycl_ext_intel_matrix]
401401

402402
Besides the `Group` and the `joint_matrix` arguments,
403403
`joint_matrix_apply` takes a C++ Callable object which is invoked once
404-
for each element of the matrix. This callable object must be invocable
405-
with a single parameter of type `T&`. Commonly, applications pass a
406-
lambda expression.
404+
for each element of the matrix. There are two cases: (1) one matrix is
405+
passed, (2) two matrices are passed.
406+
407+
===== Unary Operation
408+
In this case, `joint_matrix_apply` takes one `joint_matrix`
409+
argument. The callable object must be invocable with a single
410+
parameter of type `T&`. Commonly, applications pass a lambda
411+
expression.
407412

408413
```c++
409414
namespace sycl::ext::oneapi::experimental::matrix {
@@ -427,6 +432,39 @@ joint_matrix_apply(sg, C, [=](T &x) {
427432
});
428433
```
429434

435+
===== Binary Operation
436+
In this case, `joint_matrix_apply` takes two `joint_matrix` arguments:
437+
`jm0` and `jm1` that have the same `use`, number of rows, number of
438+
columns, and `layout`. `jm0` and `jm1` can be read-only, write-only,
439+
or read and write arguments. The callable object must be invocable
440+
with two parameters `x` and `y` of types `T0&` amd `T1&`, where `x` is
441+
an element from `jm0` and `y` is an element from `jm1`. Moreover, `x`
442+
and `y` are guaranteed to have identical coordinates in their
443+
respective matrices. Commonly, applications pass a lambda expression.
444+
445+
```c++
446+
namespace sycl::ext::oneapi::experimental::matrix {
447+
448+
template<typename Group, typename T0, typename T1, use Use,
449+
size_t Rows, size_t Cols, layout Layout, typename F>
450+
void joint_matrix_apply(Group g,
451+
joint_matrix<Group, T0, Use, Rows, Cols, Layout>& jm0,
452+
joint_matrix<Group, T1, Use, Rows, Cols, Layout>& jm1,
453+
F&& func);
454+
455+
} // namespace sycl::ext::oneapi::experimental::matrix
456+
```
457+
458+
In the following example, every element `x` of the matrix `C` is
459+
multiplied by `alpha`. The result is returned into the element `y` of
460+
the matrix `D`.
461+
462+
```c++
463+
joint_matrix_apply(sg, C, D, [=](const T &x, T &y) {
464+
y = x * alpha;
465+
});
466+
```
467+
430468
==== Prefetch
431469

432470
```c++

0 commit comments

Comments
 (0)