Skip to content

Commit cbf4b96

Browse files
committed
Address feedback
1 parent 949b851 commit cbf4b96

File tree

4 files changed

+12
-10
lines changed

4 files changed

+12
-10
lines changed

docs/build.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ To read documentation for how to build on Android, [click here](./android.md)
559559

560560
## WebGPU [In Progress]
561561

562-
The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake.
562+
The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The currrent implementation is up-to-date with Dawn commit `bed1a61`.
563563

564564
In the llama.cpp directory, build with CMake:
565565

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,8 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
264264
uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
265265
uint32_t ne = (uint32_t)ggml_nelements(node);
266266
params[0] = ne;
267-
params[1] = src_misalignment;
268-
params[2] = dst_misalignment;
267+
params[1] = src_misalignment/ggml_type_size(src->type);
268+
params[2] = dst_misalignment/ggml_type_size(node->type);
269269

270270
// Convert byte-strides to element-strides
271271
params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
@@ -881,10 +881,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
881881
ctx.name = GGML_WEBGPU_NAME;
882882
ctx.device_count = 1;
883883

884-
885-
wgpu::InstanceDescriptor instanceDescriptor{};
886-
instanceDescriptor.capabilities.timedWaitAnyEnable = true;
887-
webgpu_ctx->instance = wgpu::CreateInstance(&instanceDescriptor);
884+
wgpu::InstanceDescriptor instance_descriptor{};
885+
std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny};
886+
instance_descriptor.requiredFeatures = instance_features.data();
887+
instance_descriptor.requiredFeatureCount = instance_features.size();
888+
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
888889
GGML_ASSERT(webgpu_ctx->instance != nullptr);
889890

890891
static ggml_backend_reg reg = {

ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ var<storage, read_write> dst: array<f16>;
88

99
struct Params {
1010
ne: u32, // total number of elements
11-
offset_src: u32, // in bytes
12-
offset_dst: u32, // in bytes
11+
offset_src: u32, // in elements
12+
offset_dst: u32, // in elements
1313

1414
// Strides (in elements) — may be permuted
1515
stride_src0: u32,

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ struct MulMatParams {
22
m: u32,
33
n: u32,
44
k: u32,
5+
// all strides are in elements
56
stride_01: u32,
67
stride_11: u32,
78
stride_02: u32,
@@ -16,7 +17,7 @@ struct MulMatParams {
1617
};
1718

1819
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
19-
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns
20+
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
2021
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
2122

2223
@group(0) @binding(3) var<uniform> params: MulMatParams;

0 commit comments

Comments
 (0)