#define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG

+ /* Constants */
+
// TODO: find a better way to get the memory available
#define WEBGPU_MAX_BUFFERS 32

+ #define WEBGPU_MUL_MAT_WG_SIZE 64
+ #define WEBGPU_MUL_MAT_PARAMS_SIZE (7 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
+
+ /* End Constants */
+
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -138,18 +145,16 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buff
    );
}

- static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint8_t value, size_t offset, size_t size) {
+ static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
    wgpu::Device device = ctx->device;

    // map the host parameters buffer
    ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
    uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();

-     // This is a trick to set all bytes of a u32 to the same 1 byte value.
-     uint32_t val32 = (uint32_t) value * 0x01010101;
    params[0] = (uint32_t) offset;
    params[1] = (uint32_t) size;
-     params[2] = val32;
+     params[2] = value;
    ctx->memset_params_host_buf.Unmap();

    wgpu::BindGroupEntry entries[2];
@@ -191,7 +196,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer b
/** GGML Backend Interface */

static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_name()");
    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
    return ctx->name.c_str();
}
@@ -201,6 +205,7 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");

    // TODO: cleanup
+     GGML_UNUSED(ctx);
}

// Returns true if node has enqueued work into the queue, false otherwise
@@ -244,6 +249,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
    params[0] = (uint32_t) node->ne[1]; // number of rows in result (M)
    params[1] = (uint32_t) node->ne[0]; // number of columns in result (N)
    params[2] = (uint32_t) src0->ne[0]; // number of columns in src0/src1 (K)
+     params[3] = (uint32_t) src0->ne[2]; // batch size in dimension 2
+     params[4] = (uint32_t) src0->ne[3]; // batch size in dimension 3
+     params[5] = (uint32_t) (src1->ne[2] / src0->ne[2]); // broadcast in dimension 2
+     params[6] = (uint32_t) (src1->ne[3] / src0->ne[3]); // broadcast in dimension 3
+
    ctx->mul_mat_params_host_buf.Unmap();

    wgpu::BindGroupEntry entries[4];
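// Illustrative sketch (not part of this patch) of how the batch/broadcast parameters
// written above are meant to be consumed: ggml's mul_mat broadcasting guarantees that
// src1->ne[2]/ne[3] are integer multiples of src0->ne[2]/ne[3], so a dst batch index
// maps back to a src0 batch index by dividing by the broadcast ratio.
static inline uint32_t ggml_webgpu_src0_batch(uint32_t dst_batch, uint32_t broadcast) { // hypothetical helper
    return dst_batch / broadcast; // e.g. broadcast = 4: dst batches 0..3 all read src0 batch 0
}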
@@ -282,7 +292,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
    wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
    pass.SetPipeline(ctx->mul_mat_pipeline);
    pass.SetBindGroup(0, bind_group);
-     pass.DispatchWorkgroups(node->ne[0] * node->ne[1]);
+     pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
    pass.End();
    wgpu::CommandBuffer commands = encoder.Finish();
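// Rough sketch of the dispatch math above: one invocation per output element, rounded
// up so a partially filled workgroup still covers the tail (the shader is assumed to
// bounds-check). For example, 1000 output elements with a workgroup size of 64 need
// (1000 + 63) / 64 = 16 workgroups.
static inline uint64_t ggml_webgpu_ceil_div(uint64_t total, uint64_t wg_size) { // hypothetical helper, equivalent to the inline expression
    return (total + wg_size - 1) / wg_size;
}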
@@ -352,7 +362,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe

    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
-     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, total_offset, size);
+     // This is a trick to set all bytes of a u32 to the same 1 byte value.
+     uint32_t val32 = (uint32_t) value * 0x01010101;
+     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
}
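// Worked example of the byte-replication trick above (illustration only, not part of
// the patch): multiplying an 8-bit value by 0x01010101 copies it into every byte of
// the u32, so the memset shader writes the same fill byte everywhere.
//   value = 0xAB  ->  0xAB * 0x01010101 = 0xABABABAB
static inline uint32_t ggml_webgpu_splat_byte(uint8_t value) { // hypothetical helper
    return (uint32_t) value * 0x01010101u;
}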
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@@ -363,10 +375,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;

    // TODO: wait on this?
-     webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, size);
+     webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
+
+     if (size % 4 != 0) {
+         // If size is not a multiple of 4, we need to memset the remaining bytes
+         size_t remaining_size = size % 4;
+         // pack the remaining bytes into a uint32_t
+         uint32_t val32 = 0;
+         for (size_t i = 0; i < remaining_size; i++) {
+             ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
+         }
+         // memset the remaining bytes
+         ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
+     }
}
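// Standalone sketch of the tail handling above (hypothetical helper, not in the patch):
// WebGPU WriteBuffer sizes must be a multiple of 4, so the last 1-3 bytes are packed
// into a u32 and flushed through the memset shader instead.
static inline uint32_t ggml_webgpu_pack_tail(const void * data, size_t size) {
    size_t   rem   = size % 4; // bytes not covered by the 4-byte-aligned WriteBuffer
    uint32_t val32 = 0;
    for (size_t i = 0; i < rem; i++) {
        ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - rem + i];
    }
    return val32; // e.g. size = 6, last bytes {0x11, 0x22} -> 0x00002211 on little-endian
}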

- // TODO: we need a staging buffer for this, since WebGPU does not allow reading from storage buffers directly.
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
@@ -376,33 +399,39 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,

    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;

+     size_t final_size = size;
+     if (size % 4 != 0) {
+         // If size is not a multiple of 4, we need to round it up to the next multiple of 4
+         final_size = size + (4 - (size % 4));
+     }
+
    if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
-         webgpu_ctx->get_tensor_staging_buf.GetSize() < size) {
+         webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
        // Create a new staging buffer if it doesn't exist or is too small
        if (webgpu_ctx->get_tensor_staging_buf) {
            webgpu_ctx->get_tensor_staging_buf.Destroy();
        }
-         ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, size,
+         ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
            wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
    }

    // Copy the data from the buffer to the staging buffer
    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-     encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, size);
+     encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
    wgpu::CommandBuffer commands = encoder.Finish();
    // Submit the command buffer to the queue
    webgpu_ctx->queue.Submit(1, &commands);

    // Map the staging buffer to read the data
-     ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, size);
-     const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange();
+     ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
+     // Must specify size here since the staging buffer might be larger than the tensor size
+     const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);

    // Copy the data from the mapped range to the output buffer
    std::memcpy(data, mapped_range, size);
    webgpu_ctx->get_tensor_staging_buf.Unmap();
}
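// The rounding above is the usual "round size up to the next multiple of 4"; an
// equivalent branch-free form (illustration only, not part of the patch) would be:
static inline size_t ggml_webgpu_round_up_4(size_t size) {
    return (size + 3) & ~(size_t) 3; // e.g. 10 -> 12, 12 -> 12
}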

- // TODO
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << value << ")");
@@ -427,7 +456,6 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
/* GGML Backend Buffer Type Interface */

static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_name()");
    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
    return ctx->device_name.c_str();
}
@@ -446,14 +474,12 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
}

static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_alignment()");
    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
    return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
}

// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_get_max_size()");
    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
    return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
}
@@ -473,16 +499,13 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
}

static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_get_memory()");
-
    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
    // TODO: what do we actually want to return here?
    *free = ctx->webgpu_ctx->limits.maxBufferSize * WEBGPU_MAX_BUFFERS;
    *total = ctx->webgpu_ctx->limits.maxBufferSize * WEBGPU_MAX_BUFFERS;
}

static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_get_type()");
    GGML_UNUSED(dev);
    return GGML_BACKEND_DEVICE_TYPE_GPU;
}
@@ -526,11 +549,10 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {

static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
-     ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf,
-         3 * sizeof(uint32_t), // 3 parameters: M, N, K
+     ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
        wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst);
-     ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf,
-         3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
+     ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
+         wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);
}
// TODO: Does this need to be thread safe? Is it only called once?
@@ -617,13 +639,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
    // what should we support first?
    switch (op->op) {
        case GGML_OP_NONE:
-         case GGML_OP_RESHAPE:
-         case GGML_OP_VIEW:
-         case GGML_OP_PERMUTE:
-         case GGML_OP_TRANSPOSE:
-         case GGML_OP_MUL_MAT:
            return true;
-
+         case GGML_OP_MUL_MAT:
+             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
        default:
            return false;
    }
@@ -652,13 +670,11 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
/* GGML Backend Registration Interface */

static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg_get_name()");
    ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
    return ctx->name;
}

static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
-     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg_get_device_count()");
    ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
    return ctx->device_count;
}