Skip to content

Commit 1c396a2

Browse files
committed
Work on passing CI, implement 4D tensor multiplication
1 parent daa58e2 commit 1c396a2

File tree

5 files changed

+97
-43
lines changed

5 files changed

+97
-43
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
269269
| [Vulkan](docs/build.md#vulkan) | GPU |
270270
| [CANN](docs/build.md#cann) | Ascend NPU |
271271
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
272+
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
273+
272274
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
273275

274276
## Obtaining and quantizing models

ci/run.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
# # with VULKAN support
1717
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
1818
#
19+
# # with WebGPU support
20+
# GG_BUILD_WEBGPU=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
21+
#
1922
# # with MUSA support
2023
# GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
2124
#
@@ -81,6 +84,10 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
8184
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1"
8285
fi
8386

87+
if [ ! -z ${GG_BUILD_WEBGPU} ]; then
88+
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1"
89+
fi
90+
8491
if [ ! -z ${GG_BUILD_MUSA} ]; then
8592
# Use qy1 by default (MTT S80)
8693
MUSA_ARCH=${MUSA_ARCH:-21}

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,16 @@
1616
#define WEBGPU_LOG_DEBUG(msg) ((void) 0)
1717
#endif // GGML_WEBGPU_DEBUG
1818

19+
/* Constants */
20+
1921
// TODO: find a better way to get the memory available
2022
#define WEBGPU_MAX_BUFFERS 32
2123

24+
#define WEBGPU_MUL_MAT_WG_SIZE 64
25+
#define WEBGPU_MUL_MAT_PARAMS_SIZE (7 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
26+
27+
/* End Constants */
28+
2229
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
2330
static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
2431

@@ -138,18 +145,16 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buff
138145
);
139146
}
140147

141-
static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint8_t value, size_t offset, size_t size) {
148+
static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
142149
wgpu::Device device = ctx->device;
143150

144151
// map the host parameters buffer
145152
ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
146153
uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
147154

148-
// This is a trick to set all bytes of a u32 to the same 1 byte value.
149-
uint32_t val32 = (uint32_t)value * 0x01010101;
150155
params[0] = (uint32_t)offset;
151156
params[1] = (uint32_t)size;
152-
params[2] = val32;
157+
params[2] = value;
153158
ctx->memset_params_host_buf.Unmap();
154159

155160
wgpu::BindGroupEntry entries[2];
@@ -191,7 +196,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer b
191196
/** GGML Backend Interface */
192197

193198
static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
194-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_name()");
195199
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
196200
return ctx->name.c_str();
197201
}
@@ -201,6 +205,7 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
201205
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
202206

203207
// TODO: cleanup
208+
GGML_UNUSED(ctx);
204209
}
205210

206211
// Returns true if node has enqueued work into the queue, false otherwise
@@ -244,6 +249,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
244249
params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
245250
params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
246251
params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
252+
params[3] = (uint32_t)src0->ne[2]; // batch size in dimension 2
253+
params[4] = (uint32_t)src0->ne[3]; // batch size in dimension 3
254+
params[5] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
255+
params[6] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
256+
247257
ctx->mul_mat_params_host_buf.Unmap();
248258

249259
wgpu::BindGroupEntry entries[4];
@@ -282,7 +292,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
282292
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
283293
pass.SetPipeline(ctx->mul_mat_pipeline);
284294
pass.SetBindGroup(0, bind_group);
285-
pass.DispatchWorkgroups(node->ne[0] * node->ne[1]);
295+
pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
286296
pass.End();
287297
wgpu::CommandBuffer commands = encoder.Finish();
288298

@@ -352,7 +362,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
352362

353363
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
354364
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
355-
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, total_offset, size);
365+
// This is a trick to set all bytes of a u32 to the same 1 byte value.
366+
uint32_t val32 = (uint32_t)value * 0x01010101;
367+
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
356368
}
357369

358370
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@@ -363,10 +375,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
363375
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
364376

365377
// TODO: wait on this?
366-
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, size);
378+
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
379+
380+
if (size % 4 != 0) {
381+
// If size is not a multiple of 4, we need to memset the remaining bytes
382+
size_t remaining_size = size % 4;
383+
// pack the remaining bytes into a uint32_t
384+
uint32_t val32 = 0;
385+
for (size_t i = 0; i < remaining_size; i++) {
386+
((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
387+
}
388+
// memset the remaining bytes
389+
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
390+
}
367391
}
368392

369-
// TODO: we need a staging buffer for this, since WebGPU does not allow reading from storage buffers directly.
370393
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
371394
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
372395

@@ -376,33 +399,39 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
376399

377400
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
378401

402+
size_t final_size = size;
403+
if (size % 4 != 0) {
404+
// If size is not a multiple of 4, we need to round it up to the next multiple of 4
405+
final_size = size + (4 - (size % 4));
406+
}
407+
379408
if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
380-
webgpu_ctx->get_tensor_staging_buf.GetSize() < size) {
409+
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
381410
// Create a new staging buffer if it doesn't exist or is too small
382411
if (webgpu_ctx->get_tensor_staging_buf) {
383412
webgpu_ctx->get_tensor_staging_buf.Destroy();
384413
}
385-
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, size,
414+
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
386415
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
387416
}
388417

389418
// Copy the data from the buffer to the staging buffer
390419
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
391-
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, size);
420+
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
392421
wgpu::CommandBuffer commands = encoder.Finish();
393422
// Submit the command buffer to the queue
394423
webgpu_ctx->queue.Submit(1, &commands);
395424

396425
// Map the staging buffer to read the data
397-
ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, size);
398-
const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange();
426+
ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
427+
// Must specify size here since the staging buffer might be larger than the tensor size
428+
const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
399429

400430
// Copy the data from the mapped range to the output buffer
401431
std::memcpy(data, mapped_range, size);
402432
webgpu_ctx->get_tensor_staging_buf.Unmap();
403433
}
404434

405-
// TODO
406435
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
407436
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << value << ")");
408437

@@ -427,7 +456,6 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
427456
/* GGML Backend Buffer Type Interface */
428457

429458
static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
430-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_name()");
431459
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
432460
return ctx->device_name.c_str();
433461
}
@@ -446,14 +474,12 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
446474
}
447475

448476
static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
449-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_alignment()");
450477
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
451478
return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
452479
}
453480

454481
// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
455482
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
456-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_max_size()");
457483
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
458484
return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
459485
}
@@ -473,16 +499,13 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
473499
}
474500

475501
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
476-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_get_memory()");
477-
478502
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
479503
// TODO: what do we actually want to return here?
480504
*free = ctx->webgpu_ctx->limits.maxBufferSize * WEBGPU_MAX_BUFFERS;
481505
*total = ctx->webgpu_ctx->limits.maxBufferSize * WEBGPU_MAX_BUFFERS;
482506
}
483507

484508
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
485-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_get_type()");
486509
GGML_UNUSED(dev);
487510
return GGML_BACKEND_DEVICE_TYPE_GPU;
488511
}
@@ -526,11 +549,10 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
526549

527550
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
528551
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
529-
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf,
530-
3 * sizeof(uint32_t), // 3 parameters: M, N, K
552+
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
531553
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst);
532-
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf,
533-
3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
554+
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf,WEBGPU_MUL_MAT_PARAMS_SIZE,
555+
wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
534556
}
535557

536558
// TODO: Does this need to be thread safe? Is it only called once?
@@ -617,13 +639,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
617639
// what should we support first?
618640
switch (op->op) {
619641
case GGML_OP_NONE:
620-
case GGML_OP_RESHAPE:
621-
case GGML_OP_VIEW:
622-
case GGML_OP_PERMUTE:
623-
case GGML_OP_TRANSPOSE:
624-
case GGML_OP_MUL_MAT:
625642
return true;
626-
643+
case GGML_OP_MUL_MAT:
644+
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
627645
default:
628646
return false;
629647
}
@@ -652,13 +670,11 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
652670
/* GGML Backend Registration Interface */
653671

654672
static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
655-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg_get_name()");
656673
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
657674
return ctx->name;
658675
}
659676

660677
static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
661-
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg_get_device_count()");
662678
ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
663679
return ctx->device_count;
664680
}

ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ var<storage, read_write> output_buffer: array<u32>;
55
struct Params {
66
offset: u32, // in bytes
77
size: u32, // in bytes
8-
value: u32, // four identical values
8+
value: u32, // 4 8-bit values, which are either repeating (memset_tensor) or may be separate (cleaning up unaligned set_tensor operations)
99
};
1010

1111
@group(0) @binding(1)
Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,54 @@
11
struct MulMatParams {
22
m: u32,
33
n: u32,
4-
k: u32
4+
k: u32,
5+
bs02: u32,
6+
bs03: u32,
7+
broadcast2: u32,
8+
broadcast3: u32
59
};
610

7-
@group(0) @binding(0) var<storage, read> src0: array<f32>;
8-
@group(0) @binding(1) var<storage, read> src1: array<f32>;
9-
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
11+
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
12+
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns
13+
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
1014

1115
@group(0) @binding(3) var<uniform> params: MulMatParams;
1216

1317
@compute @workgroup_size(64)
1418
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
15-
if (global_id.x >= params.m * params.n) {
19+
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
20+
if (global_id.x >= total) {
1621
return;
1722
}
18-
let row = global_id.x / params.n;
19-
let col = global_id.x % params.n;
23+
24+
let src02_stride = params.n * params.k;
25+
let src03_stride = src02_stride * params.bs02;
26+
27+
let src12_stride = params.m * params.k;
28+
let src13_stride = src12_stride * params.bs02 * params.broadcast2;
29+
30+
let dst2_stride = params.m * params.n;
31+
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
32+
33+
let dst3_idx = global_id.x / dst3_stride;
34+
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
35+
let src13_idx = dst3_idx; // src1 is not broadcast
36+
let dst3_rem = global_id.x % dst3_stride;
37+
38+
let dst2_idx = dst3_rem / dst2_stride;
39+
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
40+
let src12_idx = dst2_idx;
41+
42+
let dst2_rem = dst3_rem % dst2_stride;
43+
44+
let row = dst2_rem / params.n; // output row
45+
let col = dst2_rem % params.n; // output column
46+
2047
var sum = 0.0;
2148
for (var i: u32 = 0u; i < params.k; i = i + 1u) {
22-
sum = sum + src0[col * params.k + i] * src1[row * params.k + i];
49+
let src0_idx = src03_idx * src03_stride + src02_idx * src02_stride + col * params.k + i;
50+
let src1_idx = src13_idx * src13_stride + src12_idx * src12_stride + row * params.k + i;
51+
sum = sum + src0[src0_idx] * src1[src1_idx];
2352
}
24-
dst[row * params.n + col] = sum;
53+
dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
2554
}

0 commit comments

Comments
 (0)