Skip to content

Commit 0f0543b

Browse files
committed
Implement permuting for mul_mat and cpy
1 parent ecb945e commit 0f0543b

File tree

3 files changed

+82
-29
lines changed

3 files changed

+82
-29
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
#define WEBGPU_MAX_BUFFERS 32
2323

2424
#define WEBGPU_MUL_MAT_WG_SIZE 64
25-
#define WEBGPU_MUL_MAT_PARAMS_SIZE (7 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
26-
#define WEBGPU_CPY_PARAMS_SIZE (3 * sizeof(uint32_t)) // number of elements to copy, alignments
25+
#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, src0/src1 strides (dims 1-3), batch sizes, broadcasts
26+
#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // element count, src/dst offsets, src/dst strides, logical shape
2727
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
2828

2929
/* End Constants */
@@ -266,10 +266,26 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
266266
ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
267267
wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
268268
uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
269-
uint32_t ne = (uint32_t)ggml_nelements(node); // number of elements to copy
269+
uint32_t ne = (uint32_t)ggml_nelements(node);
270270
params[0] = ne;
271271
params[1] = src_misalignment;
272272
params[2] = dst_misalignment;
273+
274+
// Convert byte-strides to element-strides
275+
params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
276+
params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type);
277+
params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type);
278+
params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type);
279+
params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type);
280+
params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type);
281+
params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type);
282+
params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type);
283+
// Logical shape — same for both tensors even if permuted
284+
params[11] = (uint32_t)(src->ne[0]);
285+
params[12] = (uint32_t)(src->ne[1]);
286+
params[13] = (uint32_t)(src->ne[2]);
287+
params[14] = (uint32_t)(src->ne[3]);
288+
273289
ctx->cpy_params_host_buf.Unmap();
274290

275291
wgpu::BindGroupEntry entries[3];
@@ -338,10 +354,18 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
338354
params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
339355
params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
340356
params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
341-
params[3] = (uint32_t)src0->ne[2]; // batch size in dimension 2
342-
params[4] = (uint32_t)src0->ne[3]; // batch size in dimension 3
343-
params[5] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
344-
params[6] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
357+
358+
params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1
359+
params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1
360+
params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2
361+
params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2
362+
params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3
363+
params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3
364+
365+
params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2
366+
params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3
367+
params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
368+
params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
345369

346370
ctx->mul_mat_params_host_buf.Unmap();
347371

ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,54 @@ var<storage, read_write> src: array<f32>;
77
var<storage, read_write> dst: array<f16>;
88

99
struct Params {
10-
ne: u32, // number of elements
11-
src_offset: u32, // src offset in bytes
12-
dst_offset: u32 // dst offset in bytes
10+
ne: u32, // total number of elements
11+
src_offset: u32, // in bytes
12+
dst_offset: u32, // in bytes
13+
14+
// Strides (in elements) — may be permuted
15+
stride_src0: u32,
16+
stride_src1: u32,
17+
stride_src2: u32,
18+
stride_src3: u32,
19+
20+
stride_dst0: u32,
21+
stride_dst1: u32,
22+
stride_dst2: u32,
23+
stride_dst3: u32,
24+
25+
// Logical shape (same for both tensors)
26+
ne0: u32,
27+
ne1: u32,
28+
ne2: u32,
29+
ne3: u32,
1330
};
1431

1532
@group(0) @binding(2)
1633
var<uniform> params: Params;
1734

1835
override wg_size: u32;
19-
const elems_per_thread: u32 = 4;
20-
2136
@compute @workgroup_size(wg_size)
2237
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
23-
let idx = gid.x * elems_per_thread;
24-
// chunked loop
25-
for (var j: u32 = 0u; j < elems_per_thread; j = j + 1u) {
26-
let i = idx + j;
27-
if (i < params.ne) {
28-
// Convert f32 to f16
29-
dst[dst_offset/2 + i] = f16(src[src_offset/4 + i]);
30-
}
38+
if (gid.x >= params.ne) {
39+
return;
3140
}
41+
42+
var i = gid.x;
43+
44+
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
45+
i = i % (params.ne2 * params.ne1 * params.ne0);
46+
47+
let i2 = i / (params.ne1 * params.ne0);
48+
i = i % (params.ne1 * params.ne0);
49+
50+
let i1 = i / params.ne0;
51+
let i0 = i % params.ne0;
52+
53+
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
54+
i2 * params.stride_src2 + i3 * params.stride_src3;
55+
56+
let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
57+
i2 * params.stride_dst2 + i3 * params.stride_dst3;
58+
59+
dst[params.dst_offset / 2 + dst_idx] = f16(src[params.src_offset / 4 + src_idx]);
3260
}

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ struct MulMatParams {
22
m: u32,
33
n: u32,
44
k: u32,
5+
stride_01: u32,
6+
stride_11: u32,
7+
stride_02: u32,
8+
stride_12: u32,
9+
stride_03: u32,
10+
stride_13: u32,
11+
512
bs02: u32,
613
bs03: u32,
714
broadcast2: u32,
@@ -21,12 +28,6 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
2128
return;
2229
}
2330

24-
let src02_stride = params.n * params.k;
25-
let src03_stride = src02_stride * params.bs02;
26-
27-
let src12_stride = params.m * params.k;
28-
let src13_stride = src12_stride * params.bs02 * params.broadcast2;
29-
3031
let dst2_stride = params.m * params.n;
3132
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
3233

@@ -37,7 +38,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
3738

3839
let dst2_idx = dst3_rem / dst2_stride;
3940
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
40-
let src12_idx = dst2_idx;
41+
let src12_idx = dst2_idx; // src1 is not broadcast
4142

4243
let dst2_rem = dst3_rem % dst2_stride;
4344

@@ -46,8 +47,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
4647

4748
var sum = 0.0;
4849
for (var i: u32 = 0u; i < params.k; i = i + 1u) {
49-
let src0_idx = src03_idx * src03_stride + src02_idx * src02_stride + col * params.k + i;
50-
let src1_idx = src13_idx * src13_stride + src12_idx * src12_stride + row * params.k + i;
50+
let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
51+
let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
5152
sum = sum + src0[src0_idx] * src1[src1_idx];
5253
}
5354
dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;

0 commit comments

Comments
 (0)