@@ -41,10 +41,15 @@ struct webgpu_context_struct {
     wgpu::Queue queue;
     wgpu::Limits limits;

-    // memset pipeline and parameter buffers
+    // pipelines and parameter buffers
+    // TODO: reuse params buffers for different pipelines when possible
     wgpu::ComputePipeline memset_pipeline;
     wgpu::Buffer memset_params_dev_buf;
     wgpu::Buffer memset_params_host_buf;
+    wgpu::ComputePipeline mul_mat_pipeline;
+    wgpu::Buffer mul_mat_params_dev_buf;
+    wgpu::Buffer mul_mat_params_host_buf;
+
     size_t memset_elems_per_thread;

     // Staging buffer for reading data from the GPU
@@ -87,7 +92,7 @@ struct ggml_backend_webgpu_buffer_context {

 /* WebGPU object initializations */

-static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const std::vector<wgpu::ConstantEntry> &constants = {}) {
+static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector<wgpu::ConstantEntry> &constants = {}) {
     WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
     wgpu::ShaderSourceWGSL shader_source;
     shader_source.code = shader_code;
@@ -96,6 +101,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipel
     wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);

     wgpu::ComputePipelineDescriptor pipeline_desc;
+    pipeline_desc.label = label;
     pipeline_desc.compute.module = shader_module;
     pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
     pipeline_desc.layout = nullptr; // Guessing that nullptr means auto layout
@@ -121,7 +127,7 @@ static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer

 /* WebGPU Actions */

-static void * ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
+static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
     ctx->instance.WaitAny(buffer.MapAsync(
         mode, offset, size, wgpu::CallbackMode::WaitAnyOnly,
         [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
@@ -131,15 +137,14 @@ static void * ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer bu
         }),
         UINT64_MAX
     );
-    return buffer.GetMappedRange();
 }

 static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint8_t value, size_t offset, size_t size) {
     wgpu::Device device = ctx->device;

     // map the host parameters buffer
-    uint32_t * params = (uint32_t *) ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf,
-        wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
+    ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
+    uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();

     // This is a trick to set all bytes of a u32 to the same 1 byte value.
     uint32_t val32 = (uint32_t)value * 0x01010101;
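     // e.g. value = 0xAB gives val32 = 0xABABABAB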
@@ -207,6 +212,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){

     WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");

+
     switch (node->op) {
         // no-op
         case GGML_OP_NONE:
@@ -216,6 +222,76 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
             return false;
+
+        // basic matrix multiplication for now, 2d tensors only
+        case GGML_OP_MUL_MAT: {
+            const ggml_tensor * src0 = node->src[0];
+            ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context;
+            size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs;
+            const ggml_tensor * src1 = node->src[1];
+            ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
+            size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
+            ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
+
+            size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
+
+            wgpu::Device device = ctx->device;
+
+            // map the host parameters buffer
+            ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
+                wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
+            uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
+
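+            // note: ggml's ne[0] is the innermost (column) dimension, hence the swapped indices below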
+            params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
+            params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
+            params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
+            ctx->mul_mat_params_host_buf.Unmap();
+
+            wgpu::BindGroupEntry entries[4];
+            entries[0].binding = 0; // binding for src0
+            entries[0].buffer = src0_ctx->buffer;
+            entries[0].offset = src0_offset;
+            entries[0].size = ggml_nbytes(src0);
+
+            entries[1].binding = 1; // binding for src1
+            entries[1].buffer = src1_ctx->buffer;
+            entries[1].offset = src1_offset;
+            entries[1].size = ggml_nbytes(src1);
+
+            entries[2].binding = 2; // binding for the result buffer
+            entries[2].buffer = dst_ctx->buffer;
+            entries[2].offset = dst_offset;
+            entries[2].size = ggml_nbytes(node);
+
+            entries[3].binding = 3; // binding for the parameters
+            entries[3].buffer = ctx->mul_mat_params_dev_buf;
+            entries[3].offset = 0;
+            entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
+
+            wgpu::BindGroupDescriptor bind_group_desc;
+            bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
+            bind_group_desc.entryCount = 4;
+            bind_group_desc.entries = entries;
+            wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
+
+            wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
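+            // copy the params from the host staging buffer into the device uniform buffer before the pass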
+            encoder.CopyBufferToBuffer(
+                ctx->mul_mat_params_host_buf, 0,
+                ctx->mul_mat_params_dev_buf, 0,
+                ctx->mul_mat_params_dev_buf.GetSize()
+            );
+            wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+            pass.SetPipeline(ctx->mul_mat_pipeline);
+            pass.SetBindGroup(0, bind_group);
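+            // naive first pass: dispatch one workgroup per element of the result matrix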
+            pass.DispatchWorkgroups(node->ne[0] * node->ne[1]);
+            pass.End();
+            wgpu::CommandBuffer commands = encoder.Finish();
+
+            // TODO: don't submit here, batch submissions
+            ctx->queue.Submit(1, &commands);
+            return true;
+        }
+
         default:
             return false;
     }
@@ -230,6 +306,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
     }
+
+    return GGML_STATUS_SUCCESS;
 }

 static ggml_backend_i ggml_backend_webgpu_i = {
@@ -317,8 +395,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
     webgpu_ctx->queue.Submit(1, &commands);

     // Map the staging buffer to read the data
-    const void * mapped_range = ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf,
-        wgpu::MapMode::Read, 0, size);
+    ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, size);
+    const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange();

     // Copy the data from the mapped range to the output buffer
     std::memcpy(data, mapped_range, size);
@@ -439,14 +517,23 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
     constants[0].value = max_wg_size;
     constants[1].key = "elems_per_thread";
     constants[1].value = webgpu_ctx->memset_elems_per_thread;
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
     ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
         3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
         wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst);
     ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
         3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
 }

+static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
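+    // params go through a host-visible staging buffer: MapWrite usage can only be combined with CopySrc, not Uniform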
+    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf,
+        3 * sizeof(uint32_t), // 3 parameters: M, N, K
+        wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst);
+    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf,
+        3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
+}
+
 // TODO: Does this need to be thread safe? Is it only called once?
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
@@ -485,6 +572,7 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
     webgpu_ctx->queue = webgpu_ctx->device.GetQueue();

     ggml_webgpu_init_memset_pipeline(webgpu_ctx);
+    ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);

     static ggml_backend_webgpu_context backend_ctx;
     backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
@@ -534,6 +622,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+        case GGML_OP_MUL_MAT:
             return true;

         default:
@@ -654,4 +743,10 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
     return &reg;
 }

+ggml_backend_t ggml_backend_webgpu_init(void) {
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
+
+    return ggml_backend_webgpu_device_init(dev, nullptr);
+}
+
 GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
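
Aside: a minimal sketch of how the new `ggml_backend_webgpu_init()` entry point could be driven end to end. This is hypothetical test code, not part of this diff; `main`, the tensor shapes, and the memory-size estimate are illustrative assumptions, while the graph-building calls are the standard public ggml API.

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

int main() {
    // create the WebGPU backend via the new entry point
    ggml_backend_t backend = ggml_backend_webgpu_init();

    // build a tiny 2d matmul graph; no_alloc so tensor data lives in backend buffers
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // ggml convention: ne[0] is the inner dimension, and mul_mat contracts over it
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2); // ne = [4, 2]
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3); // ne = [4, 3]
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);                      // ne = [2, 3]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    // allocate all tensors of the context in a backend (GPU) buffer
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // ... upload data with ggml_backend_tensor_set(), then run the graph
    ggml_backend_graph_compute(backend, gf);
    // ... read the result back with ggml_backend_tensor_get()

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```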