Skip to content

Commit d036f10

Browse files
committed
Basic mat mul working
1 parent 3d92436 commit d036f10

File tree

3 files changed

+132
-9
lines changed

3 files changed

+132
-9
lines changed

ggml/include/ggml-webgpu.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ extern "C" {
99

1010
#define GGML_WEBGPU_NAME "WebGPU"
1111

12+
// Needed for examples in ggml
13+
GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
14+
1215
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
1316

1417
#ifdef __cplusplus

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 104 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,15 @@ struct webgpu_context_struct {
4141
wgpu::Queue queue;
4242
wgpu::Limits limits;
4343

44-
// memset pipeline and parameter buffers
44+
// pipelines and parameter buffers
45+
// TODO: reuse params buffers for different pipelines when possible
4546
wgpu::ComputePipeline memset_pipeline;
4647
wgpu::Buffer memset_params_dev_buf;
4748
wgpu::Buffer memset_params_host_buf;
49+
wgpu::ComputePipeline mul_mat_pipeline;
50+
wgpu::Buffer mul_mat_params_dev_buf;
51+
wgpu::Buffer mul_mat_params_host_buf;
52+
4853
size_t memset_elems_per_thread;
4954

5055
// Staging buffer for reading data from the GPU
@@ -87,7 +92,7 @@ struct ggml_backend_webgpu_buffer_context {
8792

8893
/* WebGPU object initializations */
8994

90-
static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const std::vector<wgpu::ConstantEntry> &constants = {}) {
95+
static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector<wgpu::ConstantEntry> &constants = {}) {
9196
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
9297
wgpu::ShaderSourceWGSL shader_source;
9398
shader_source.code = shader_code;
@@ -96,6 +101,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipel
96101
wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
97102

98103
wgpu::ComputePipelineDescriptor pipeline_desc;
104+
pipeline_desc.label = label;
99105
pipeline_desc.compute.module = shader_module;
100106
pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
101107
pipeline_desc.layout = nullptr; // nullptr requests an automatically generated pipeline layout (required for GetBindGroupLayout below)
@@ -121,7 +127,7 @@ static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer
121127

122128
/** WebGPU Actions */
123129

124-
static void * ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
130+
static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
125131
ctx->instance.WaitAny(buffer.MapAsync(
126132
mode, offset, size, wgpu::CallbackMode::WaitAnyOnly,
127133
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
@@ -131,15 +137,14 @@ static void * ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer bu
131137
}),
132138
UINT64_MAX
133139
);
134-
return buffer.GetMappedRange();
135140
}
136141

137142
static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint8_t value, size_t offset, size_t size) {
138143
wgpu::Device device = ctx->device;
139144

140145
// map the host parameters buffer
141-
uint32_t * params = (uint32_t *)ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf,
142-
wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
146+
ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
147+
uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
143148

144149
// This is a trick to set all bytes of a u32 to the same 1 byte value.
145150
uint32_t val32 = (uint32_t)value * 0x01010101;
@@ -207,6 +212,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
207212

208213
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
209214

215+
210216
switch (node->op) {
211217
// no-op
212218
case GGML_OP_NONE:
@@ -216,6 +222,76 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
216222
case GGML_OP_PERMUTE:
217223
case GGML_OP_TRANSPOSE:
218224
return false;
225+
226+
// basic matrix multiplication for now, 2d tensors only
227+
case GGML_OP_MUL_MAT: {
228+
const ggml_tensor * src0 = node->src[0];
229+
ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context;
230+
size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs;
231+
const ggml_tensor * src1 = node->src[1];
232+
ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
233+
size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
234+
ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
235+
236+
size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
237+
238+
wgpu::Device device = ctx->device;
239+
240+
// map the host parameters buffer
241+
ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
242+
wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
243+
uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
244+
245+
params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
246+
params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
247+
params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
248+
ctx->mul_mat_params_host_buf.Unmap();
249+
250+
wgpu::BindGroupEntry entries[4];
251+
entries[0].binding = 0; // binding for src0
252+
entries[0].buffer = src0_ctx->buffer;
253+
entries[0].offset = src0_offset;
254+
entries[0].size = ggml_nbytes(src0);
255+
256+
entries[1].binding = 1; // binding for src1
257+
entries[1].buffer = src1_ctx->buffer;
258+
entries[1].offset = src1_offset;
259+
entries[1].size = ggml_nbytes(src1);
260+
261+
entries[2].binding = 2; // binding for dst
262+
entries[2].buffer = dst_ctx->buffer;
263+
entries[2].offset = dst_offset;
264+
entries[2].size = ggml_nbytes(node);
265+
266+
entries[3].binding = 3; // binding for the parameters
267+
entries[3].buffer = ctx->mul_mat_params_dev_buf;
268+
entries[3].offset = 0;
269+
entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
270+
271+
wgpu::BindGroupDescriptor bind_group_desc;
272+
bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
273+
bind_group_desc.entryCount = 4;
274+
bind_group_desc.entries = entries;
275+
wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
276+
277+
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
278+
encoder.CopyBufferToBuffer(
279+
ctx->mul_mat_params_host_buf, 0,
280+
ctx->mul_mat_params_dev_buf, 0,
281+
ctx->mul_mat_params_dev_buf.GetSize()
282+
);
283+
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
284+
pass.SetPipeline(ctx->mul_mat_pipeline);
285+
pass.SetBindGroup(0, bind_group);
286+
pass.DispatchWorkgroups(node->ne[0] * node->ne[1]);
287+
pass.End();
288+
wgpu::CommandBuffer commands = encoder.Finish();
289+
290+
// TODO: don't submit here; batch submissions instead
291+
ctx->queue.Submit(1, &commands);
292+
return true;
293+
}
294+
219295
default:
220296
return false;
221297
}
@@ -230,6 +306,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
230306
for (int i = 0; i < cgraph->n_nodes; i++) {
231307
ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
232308
}
309+
310+
return GGML_STATUS_SUCCESS;
233311
}
234312

235313
static ggml_backend_i ggml_backend_webgpu_i = {
@@ -317,8 +395,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
317395
webgpu_ctx->queue.Submit(1, &commands);
318396

319397
// Map the staging buffer to read the data
320-
const void * mapped_range = ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf,
321-
wgpu::MapMode::Read, 0, size);
398+
ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, size);
399+
const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange();
322400

323401
// Copy the data from the mapped range to the output buffer
324402
std::memcpy(data, mapped_range, size);
@@ -439,14 +517,23 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
439517
constants[0].value = max_wg_size;
440518
constants[1].key = "elems_per_thread";
441519
constants[1].value = webgpu_ctx->memset_elems_per_thread;
442-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, constants);
520+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
443521
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
444522
3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
445523
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst);
446524
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
447525
3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
448526
}
449527

528+
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
529+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
530+
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf,
531+
3 * sizeof(uint32_t), // 3 parameters: M, N, K
532+
wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst);
533+
ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf,
534+
3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
535+
}
536+
450537
// TODO: Does this need to be thread safe? Is it only called once?
451538
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
452539
GGML_UNUSED(params);
@@ -485,6 +572,7 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
485572
webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
486573

487574
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
575+
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
488576

489577
static ggml_backend_webgpu_context backend_ctx;
490578
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
@@ -534,6 +622,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
534622
case GGML_OP_VIEW:
535623
case GGML_OP_PERMUTE:
536624
case GGML_OP_TRANSPOSE:
625+
case GGML_OP_MUL_MAT:
537626
return true;
538627

539628
default:
@@ -654,4 +743,10 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
654743
return &reg;
655744
}
656745

746+
ggml_backend_t ggml_backend_webgpu_init(void) {
747+
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
748+
749+
return ggml_backend_webgpu_device_init(dev, nullptr);
750+
}
751+
657752
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
struct MulMatParams {
    m: u32,
    n: u32,
    k: u32
};

@group(0) @binding(0) var<storage, read> src0: array<f32>;
@group(0) @binding(1) var<storage, read> src1: array<f32>;
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;

@group(0) @binding(3) var<uniform> params: MulMatParams;

// Naive matrix multiplication: one thread per output element.
// dst is m x n; src0 and src1 are both stored with contiguous rows of
// length k, so dst[r][c] = dot(row c of src0, row r of src1).
// Threads past the m*n outputs (dispatch may over-provision) exit early.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let total_elems = params.m * params.n;
    if (global_id.x >= total_elems) {
        return;
    }
    let out_row = global_id.x / params.n;
    let out_col = global_id.x % params.n;
    var acc = 0.0;
    for (var idx: u32 = 0u; idx < params.k; idx = idx + 1u) {
        acc = acc + src0[out_col * params.k + idx] * src1[out_row * params.k + idx];
    }
    dst[out_row * params.n + out_col] = acc;
}

0 commit comments

Comments
 (0)