|
22 | 22 | #define WEBGPU_MAX_BUFFERS 32
|
23 | 23 |
|
24 | 24 | #define WEBGPU_MUL_MAT_WG_SIZE 64
|
25 |
| -#define WEBGPU_MUL_MAT_PARAMS_SIZE (7 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts |
26 |
| -#define WEBGPU_CPY_PARAMS_SIZE (3 * sizeof(uint32_t)) // number of elements to copy, alignments |
| 25 | +#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts |
| 26 | +#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets |
27 | 27 | #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
28 | 28 |
|
29 | 29 | /* End Constants */
|
@@ -266,10 +266,26 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
|
266 | 266 | ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
|
267 | 267 | wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
|
268 | 268 | uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
|
269 |
| - uint32_t ne = (uint32_t)ggml_nelements(node); // number of elements to copy |
| 269 | + uint32_t ne = (uint32_t)ggml_nelements(node); |
270 | 270 | params[0] = ne;
|
271 | 271 | params[1] = src_misalignment;
|
272 | 272 | params[2] = dst_misalignment;
|
| 273 | + |
| 274 | + // Convert byte-strides to element-strides |
| 275 | + params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type); |
| 276 | + params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type); |
| 277 | + params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type); |
| 278 | + params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type); |
| 279 | + params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type); |
| 280 | + params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type); |
| 281 | + params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type); |
| 282 | + params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type); |
| 283 | + // Logical shape — same for both tensors even if permuted |
| 284 | + params[11] = (uint32_t)(src->ne[0]); |
| 285 | + params[12] = (uint32_t)(src->ne[1]); |
| 286 | + params[13] = (uint32_t)(src->ne[2]); |
| 287 | + params[14] = (uint32_t)(src->ne[3]); |
| 288 | + |
273 | 289 | ctx->cpy_params_host_buf.Unmap();
|
274 | 290 |
|
275 | 291 | wgpu::BindGroupEntry entries[3];
|
@@ -338,10 +354,18 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
|
338 | 354 | params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
|
339 | 355 | params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
|
340 | 356 | params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
|
341 |
| - params[3] = (uint32_t)src0->ne[2]; // batch size in dimension 2 |
342 |
| - params[4] = (uint32_t)src0->ne[3]; // batch size in dimension 3 |
343 |
| - params[5] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2 |
344 |
| - params[6] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3 |
| 357 | + |
| 358 | + params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1 |
| 359 | + params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1 |
| 360 | + params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2 |
| 361 | + params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2 |
| 362 | + params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3 |
| 363 | + params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3 |
| 364 | + |
| 365 | + params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2 |
| 366 | + params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3 |
| 367 | + params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2 |
| 368 | + params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3 |
345 | 369 |
|
346 | 370 | ctx->mul_mat_params_host_buf.Unmap();
|
347 | 371 |
|
|
0 commit comments