[WebGPU] Gradients are always zero for RNN models on Pixel 10 Pro #8590

@shin2sasa

Environment
Device: Google Pixel 10 Pro
OS: Android 16 (Build: BD3A.250721.001.B7)
Browser: Chrome 140.0.7339.52
TensorFlow.js Packages:
  @tensorflow/tfjs: 4.22.0
  @tensorflow/tfjs-backend-webgpu: 4.22.0
Description
On the Google Pixel 10 Pro with the WebGPU backend, the backward pass for RNN models consistently produces all-zero gradients. Because the weights are never updated, the model cannot learn.
Key observations:
The forward pass works correctly. Operations such as tf.matMul return the expected results, regardless of the WEBGPU_MATMUL_PROGRAM_TYPE flag setting.
The issue is isolated to gradient calculation during the backward pass.
The exact same code runs correctly on the WebGL backend on the same device.
The issue appears to be specific to the Pixel 10 Pro; other WebGPU-enabled Android devices do not exhibit this behavior.
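
To put a number on the WebGL versus WebGPU difference, the gradients from both backends can be compared element by element. The following is a minimal sketch in plain JavaScript; `maxAbsDiff` is a hypothetical helper, not part of TensorFlow.js, and the two arrays are placeholder values standing in for the result of `.array()` on one gradient from each backend. A meaningful comparison assumes identical weights and input on both backends, which the reproduction code does not guarantee because it uses random initialization.

```javascript
// Hypothetical helper: largest absolute element-wise difference between two
// nested numeric arrays of identical shape (e.g. results of tensor.array()).
function maxAbsDiff(a, b) {
  if (Array.isArray(a) !== Array.isArray(b)) {
    throw new Error('Shape mismatch: array compared with scalar');
  }
  if (!Array.isArray(a)) {
    return Math.abs(a - b);
  }
  if (a.length !== b.length) {
    throw new Error(`Shape mismatch: ${a.length} vs ${b.length}`);
  }
  let worst = 0;
  for (let i = 0; i < a.length; i++) {
    worst = Math.max(worst, maxAbsDiff(a[i], b[i]));
  }
  return worst;
}

// Placeholder values standing in for one gradient computed on each backend.
const webglGrad = [[0.12, -0.5], [0.03, 0.25]];
const webgpuGrad = [[0, 0], [0, 0]];
console.log(maxAbsDiff(webglGrad, webgpuGrad)); // 0.5
```

A difference close to floating-point noise would indicate agreement between backends; a difference on the order of the gradient magnitude itself matches the behavior described here.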
Steps to Reproduce
The following code shows that all gradients for a model with a simpleRNN layer are zero.

// Assumes an ES module build; with <script> tags, tf is already a global.
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu';

// Select the WebGPU backend, then wait for it to finish initializing.
await tf.setBackend('webgpu');
await tf.ready();

const maxlen = 10;
const testInput = tf.randomNormal([1, maxlen, 1]);

const model = tf.sequential();
model.add(tf.layers.simpleRNN({ units: 20, inputShape: [maxlen, 1] }));
model.add(tf.layers.dense({ units: 1 }));

model.compile({ loss: 'meanSquaredError', optimizer: tf.train.adam(0.001) });

// Compute gradients of the mean prediction with respect to all trainable weights
const { grads } = model.optimizer.computeGradients(() => model.predict(testInput).mean());

// Print gradients to the console
console.log('Calculated Gradients:');
for (const name in grads) {
  console.log(name);
  // Use async .array() to avoid warnings and follow best practices
  const gradValues = await grads[name].array();
  console.log(gradValues);
}

Actual Result
All calculated gradients are zero.

simple_rnn_SimpleRNN1/kernel
[[0, 0, 0, ...], [0, 0, 0, ...], ...]

simple_rnn_SimpleRNN1/recurrent_kernel
[[0, 0, 0, ...], [0, 0, 0, ...], ...]

simple_rnn_SimpleRNN1/bias
[0, 0, 0, ...]

dense_Dense1/kernel
[[0], [0], [0], ...]

dense_Dense1/bias
[0]
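
Because the console output is truncated, it can help to confirm programmatically that every entry is exactly zero instead of merely very small or NaN. The sketch below works on the plain nested arrays returned by `.array()`; `summarizeGradient` is a hypothetical helper name, not a TensorFlow.js API.

```javascript
// Hypothetical helper: walks a nested numeric array (as returned by
// tensor.array()) and reports how many entries are non-zero.
function summarizeGradient(values) {
  const summary = { count: 0, nonZero: 0, maxAbs: 0, hasNaN: false };
  const visit = (v) => {
    if (Array.isArray(v)) {
      v.forEach(visit);
      return;
    }
    summary.count += 1;
    if (Number.isNaN(v)) {
      summary.hasNaN = true;
    } else if (v !== 0) {
      summary.nonZero += 1;
      summary.maxAbs = Math.max(summary.maxAbs, Math.abs(v));
    }
  };
  visit(values);
  return summary;
}

// An all-zero 2x3 gradient yields count 6 and nonZero 0.
console.log(summarizeGradient([[0, 0, 0], [0, 0, 0]]));
```

In the reproduction loop, calling `summarizeGradient(gradValues)` for each variable would give a one-line summary per weight instead of a truncated array dump.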

Expected Result
Gradients should be non-zero so that the model can learn, as observed on the WebGL backend and on other devices.
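
At least one of these gradients has a known closed-form value. With a batch size of 1 and a single output unit, `model.predict(testInput).mean()` is simply the dense layer's output `h · w + b` (assuming the dense layer's default linear activation), so the gradient with respect to `dense_Dense1/bias` should be exactly 1 regardless of the weights. The sketch below illustrates this with a central finite difference in plain JavaScript; `denseOutput`, `biasGradient`, and the numbers used are illustrative and not taken from the model.

```javascript
// Illustrative stand-in for the final dense layer: y = h · w + b.
function denseOutput(h, w, b) {
  let y = b;
  for (let i = 0; i < h.length; i++) {
    y += h[i] * w[i];
  }
  return y;
}

// Central finite-difference estimate of dy/db.
function biasGradient(h, w, b, eps = 1e-3) {
  return (denseOutput(h, w, b + eps) - denseOutput(h, w, b - eps)) / (2 * eps);
}

// Made-up hidden state and weights; the result does not depend on them.
const hidden = [0.3, -0.7, 0.1];
const weights = [0.5, 0.2, -0.4];
console.log(biasGradient(hidden, weights, 0).toFixed(6)); // 1.000000
```

A reported value of `[0]` for the dense bias therefore cannot be explained by the particular random weights or input.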

Additional Context
GPU Adapter Information
The GPU vendor is reported as img-tec. This may be relevant, as the issue could be related to the driver implementation for this hardware.

Full adapter, limits, and features info:

// --- (await navigator.gpu.requestAdapter()).info ---
{
  "vendor": "img-tec",
  "architecture": "",
  "device": "",
  "description": "",
  "subgroupMinSize": 4,
  "subgroupMaxSize": 128,
  "isFallbackAdapter": false
}

// --- (await navigator.gpu.requestAdapter()).limits ---
{
  "maxTextureDimension1D": 16384,
  "maxTextureDimension2D": 16384,
  "maxTextureDimension3D": 2048,
  "maxTextureArrayLayers": 2048,
  "maxBindGroups": 4,
  "maxBindingsPerBindGroup": 1000,
  "maxDynamicUniformBuffersPerPipelineLayout": 10,
  "maxDynamicStorageBuffersPerPipelineLayout": 8,
  "maxSampledTexturesPerShaderStage": 16,
  "maxSamplersPerShaderStage": 16,
  "maxStorageBuffersPerShaderStage": 10,
  "maxStorageTexturesPerShaderStage": 8,
  "maxUniformBuffersPerShaderStage": 12,
  "maxUniformBufferBindingSize": 65536,
  "maxStorageBufferBindingSize": 134217728,
  "minUniformBufferOffsetAlignment": 256,
  "minStorageBufferOffsetAlignment": 256,
  "maxVertexBuffers": 8,
  "maxBufferSize": 2147483648,
  "maxVertexAttributes": 16,
  "maxVertexBufferArrayStride": 2048,
  "maxInterStageShaderVariables": 28,
  "maxColorAttachments": 8,
  "maxColorAttachmentBytesPerSample": 32,
  "maxComputeWorkgroupStorageSize": 32768,
  "maxComputeInvocationsPerWorkgroup": 1024,
  "maxComputeWorkgroupSizeX": 1024,
  "maxComputeWorkgroupSizeY": 1024,
  "maxComputeWorkgroupSizeZ": 64,
  "maxComputeWorkgroupsPerDimension": 65535,
  "maxBindGroupsPlusVertexBuffers": 24
}

// --- (await navigator.gpu.requestAdapter()).features ---
// (Set with 13 features)
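
For anyone who wants to collect the same information on another device, including the individual feature names that are only summarized above, a sketch along these lines may help. `collectAdapterReport` is a hypothetical helper; it takes the `navigator.gpu` object as a parameter and relies only on `requestAdapter()` and the adapter's `info`, `limits`, and `features` properties. It assumes a browser recent enough to expose `adapter.info`.

```javascript
// Hypothetical helper: gathers adapter info, limits, and features into a plain
// object that can be pasted into a bug report. Pass navigator.gpu as `gpu`.
async function collectAdapterReport(gpu) {
  const adapter = await gpu.requestAdapter();
  if (!adapter) {
    throw new Error('No WebGPU adapter available');
  }
  // info and limits expose their fields as getters, so JSON.stringify on the
  // objects themselves may print {}; copy the fields with for...in instead.
  const copyFields = (source) => {
    const out = {};
    for (const key in source) {
      if (typeof source[key] !== 'function') {
        out[key] = source[key];
      }
    }
    return out;
  };
  return {
    info: copyFields(adapter.info),
    limits: copyFields(adapter.limits),
    features: [...adapter.features].sort(),
  };
}

// Usage in a browser with WebGPU enabled:
// collectAdapterReport(navigator.gpu)
//   .then((report) => console.log(JSON.stringify(report, null, 2)));
```

Sorting the feature names makes reports from different devices easier to diff.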
