
Introduction of parameterizable workgroups for compute shaders and dispatch sizes for the renderer #31402


Open
wants to merge 13 commits into base: dev
25 changes: 19 additions & 6 deletions src/nodes/gpgpu/ComputeNode.js
@@ -20,9 +20,9 @@ class ComputeNode extends Node {
*
* @param {Node} computeNode - TODO
* @param {number} count - TODO.
* @param {Array<number>} [workgroupSize=[64]] - TODO.
* @param {Array<number>} [workgroupSize = [ 64, 1, 1 ]]
*/
constructor( computeNode, count, workgroupSize = [ 64 ] ) {
constructor( computeNode, count, workgroupSize = [ 64, 1, 1 ] ) {

super( 'void' );

@@ -53,7 +53,7 @@ class ComputeNode extends Node {
* TODO
*
* @type {Array<number>}
* @default [64]
* @default [ 64, 1, 1 ]
*/
this.workgroupSize = workgroupSize;

@@ -219,10 +219,23 @@ export default ComputeNode;
* @tsl
* @function
* @param {Node} node - TODO
* @param {number} count - TODO.
* @param {Array<number>} [workgroupSize=[64]] - TODO.
* @param {number} countOrWorkgroupSize - TODO.
* @param {Array<number>} [workgroupSize=[ 64, 1, 1 ]]
* @returns {AtomicFunctionNode}
*/
export const compute = ( node, count, workgroupSize ) => nodeObject( new ComputeNode( nodeObject( node ), count, workgroupSize ) );
export const compute = ( node, countOrWorkgroupSize, workgroupSize ) => {

let count = countOrWorkgroupSize;

if ( Array.isArray( countOrWorkgroupSize ) ) {

workgroupSize = countOrWorkgroupSize;
count = null;

}
Collaborator

Assuming the project agrees with the API change, I'm wondering if we should deprecate the old code path where "count" is passed here.

Contributor Author

I've set it up so that count is still used when no workgroup size is specified, so existing user code and examples continue to run as before. But yes, I would also be in favor of removing count from the code in the long term, after the usual transition period. That would also significantly simplify the compute call in the backend.


return nodeObject( new ComputeNode( nodeObject( node ), count, workgroupSize ) );

};

addMethodChaining( 'compute', compute );
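
The backward-compatible overload discussed in the thread above can be illustrated in isolation. This is only a sketch: `resolveComputeArgs` is a hypothetical helper that is not part of the PR or of three.js, and the `[ 64, 1, 1 ]` default is assumed from the updated `ComputeNode` constructor shown in this diff.

```javascript
// Hypothetical helper (not in three.js): mirrors the argument handling of compute().
function resolveComputeArgs( countOrWorkgroupSize, workgroupSize ) {

	let count = countOrWorkgroupSize;

	if ( Array.isArray( countOrWorkgroupSize ) ) {

		// New form: compute( [ x, y, z ] ) carries no element count.
		workgroupSize = countOrWorkgroupSize;
		count = null;

	}

	// Assumed default, taken from the ComputeNode constructor in this diff.
	if ( workgroupSize === undefined ) workgroupSize = [ 64, 1, 1 ];

	return { count, workgroupSize };

}

// Legacy call: a plain count keeps working.
const legacyArgs = resolveComputeArgs( 1024 );

// New call: an array is interpreted as the workgroup size.
const explicitArgs = resolveComputeArgs( [ 8, 8, 1 ] );
```

With this shape, `count === null` is the signal that the caller is expected to provide a dispatch size at render time instead.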
5 changes: 3 additions & 2 deletions src/renderers/common/Renderer.js
@@ -2308,9 +2308,10 @@ class Renderer {
* if the renderer has been initialized.
*
* @param {Node|Array<Node>} computeNodes - The compute node(s).
* @param {Array<number>} dispatchSize - Array with [ x,y,z ] values for dispatch.
* @return {Promise|undefined} A Promise that resolve when the compute has finished. Only returned when the renderer has not been initialized.
*/
compute( computeNodes ) {
compute( computeNodes, dispatchSize = [ 0, 0, 0 ] ) {

if ( this._isDeviceLost === true ) return;

@@ -2389,7 +2390,7 @@ class Renderer {
const computeBindings = bindings.getForCompute( computeNode );
const computePipeline = pipelines.getForCompute( computeNode, computeBindings );

backend.compute( computeNodes, computeNode, computeBindings, computePipeline );
backend.compute( computeNodes, computeNode, computeBindings, computePipeline, dispatchSize );

}

45 changes: 29 additions & 16 deletions src/renderers/webgpu/WebGPUBackend.js
@@ -1318,10 +1318,15 @@ class WebGPUBackend extends Backend {
* @param {Node} computeNode - The compute node.
* @param {Array<BindGroup>} bindings - The bindings.
* @param {ComputePipeline} pipeline - The compute pipeline.
* @param {Array<number>} dispatchSize - Array with [x,y,z] values for dispatch.
*/
compute( computeGroup, computeNode, bindings, pipeline ) {
compute( computeGroup, computeNode, bindings, pipeline, dispatchSize ) {

const computeNodeData = this.get( computeNode );
const { passEncoderGPU } = this.get( computeGroup );
const isValid = dispatchSize[ 0 ] > 0 && dispatchSize[ 1 ] > 0 && dispatchSize[ 2 ] > 0;

dispatchSize = isValid ? dispatchSize : computeNodeData;
Collaborator

It's valid to pass only 1 or 2 values to @workgroup_size when defining the parameters inline; the remaining dimensions are then implicitly set to 1 (see the spec here). It would be nice to accept a workgroup size of 1 or 2 values here as well, to align with the native behavior and for the sake of ergonomics:

const computeShader = wgslFn( `fn fragmentShader() -> void {}` );

// all are equivalent
const kernel1 = computeShader().compute( [ 16 ] );
const kernel2 = computeShader().compute( [ 16, 1 ] );
const kernel3 = computeShader().compute( [ 16, 1, 1 ] );

Contributor Author

I hadn't noticed that in the W3C documentation. I'll take a look at it, thanks.

Contributor Author

This can be implemented in two ways. Since it makes no difference in performance, I prefer to check the length of the workgroupSize array before calling the WGSL code builder and pad it with ones if it has fewer than three values.
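
The padding approach described in this reply could look roughly like the following. It is a minimal sketch: `padToThree` is a hypothetical name, and the choice to reject empty or over-long arrays is an assumption rather than something the PR specifies.

```javascript
// Hypothetical helper (not in the PR): pad a 1- or 2-element size to [ x, y, z ].
function padToThree( values ) {

	if ( ! Array.isArray( values ) || values.length < 1 || values.length > 3 ) {

		// Assumption: anything outside 1..3 entries is treated as a caller error.
		throw new Error( 'size must be an array with 1 to 3 elements' );

	}

	// Copy so the caller's array is not mutated.
	const padded = values.slice();

	// Missing dimensions default to 1, matching @workgroup_size in WGSL.
	while ( padded.length < 3 ) padded.push( 1 );

	return padded;

}

const paddedFromOne = padToThree( [ 16 ] );
const paddedFromTwo = padToThree( [ 16, 4 ] );
const paddedFromThree = padToThree( [ 16, 4, 2 ] );
```

Running this before the code builder would keep the three-way destructuring in `_getWGSLComputeCode` unchanged.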


// pipeline

@@ -1340,30 +1345,38 @@ class WebGPUBackend extends Backend {

}

const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension;
if ( isValid ) {

const computeNodeData = this.get( computeNode );
passEncoderGPU.dispatchWorkgroups(
dispatchSize[ 0 ],
dispatchSize[ 1 ],
dispatchSize[ 2 ]
);

if ( computeNodeData.dispatchSize === undefined ) computeNodeData.dispatchSize = { x: 0, y: 1, z: 1 };
} else {

const { dispatchSize } = computeNodeData;
const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension;

if ( computeNode.dispatchCount > maxComputeWorkgroupsPerDimension ) {
if ( computeNodeData.dispatchSize === undefined ) computeNodeData.dispatchSize = { x: 0, y: 1, z: 1 };

dispatchSize.x = Math.min( computeNode.dispatchCount, maxComputeWorkgroupsPerDimension );
dispatchSize.y = Math.ceil( computeNode.dispatchCount / maxComputeWorkgroupsPerDimension );
if ( computeNode.dispatchCount > maxComputeWorkgroupsPerDimension ) {

} else {
dispatchSize.x = Math.min( computeNode.dispatchCount, maxComputeWorkgroupsPerDimension );
dispatchSize.y = Math.ceil( computeNode.dispatchCount / maxComputeWorkgroupsPerDimension );

} else {

dispatchSize.x = computeNode.dispatchCount;
dispatchSize.x = computeNode.dispatchCount;

}
}

passEncoderGPU.dispatchWorkgroups(
dispatchSize.x,
dispatchSize.y,
dispatchSize.z
);
passEncoderGPU.dispatchWorkgroups(
dispatchSize.x,
dispatchSize.y,
dispatchSize.z
);

}

}
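
For reference, the count-based fallback in the `else` branch above can be sketched as a pure function. `dispatchFromCount` is a hypothetical name, and the limit of 65535 used in the example is the WebGPU default for `maxComputeWorkgroupsPerDimension`; a real device may report a different value.

```javascript
// Hypothetical helper: derive a 2D dispatch from a 1D workgroup count,
// following the fallback logic shown in the diff.
function dispatchFromCount( dispatchCount, maxPerDimension ) {

	const size = { x: 0, y: 1, z: 1 };

	if ( dispatchCount > maxPerDimension ) {

		// Too many workgroups for one dimension: spill the remainder into y.
		size.x = Math.min( dispatchCount, maxPerDimension );
		size.y = Math.ceil( dispatchCount / maxPerDimension );

	} else {

		size.x = dispatchCount;

	}

	return size;

}

// Assumed limit: the WebGPU default of 65535 workgroups per dimension.
const smallDispatch = dispatchFromCount( 1000, 65535 );
const largeDispatch = dispatchFromCount( 100000, 65535 );
```

Note that the 2D split can launch more workgroups than requested (here 65535 * 2 instead of 100000), so the shader is expected to ignore the surplus invocations.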

14 changes: 11 additions & 3 deletions src/renderers/webgpu/nodes/WGSLNodeBuilder.js
@@ -1898,7 +1898,11 @@ ${ flowData.code }

} else {

this.computeShader = this._getWGSLComputeCode( shadersData.compute, ( this.object.workgroupSize || [ 64 ] ).join( ', ' ) );
const workgroupSize = this.object.workgroupSize || [ 64, 1, 1 ];

if ( workgroupSize.length !== 3 ) throw new Error( 'workgroupSize must have 3 elements' );
Collaborator

Likewise, dispatchWorkgroups can also take 1 or 2 parameters, with the remainder defaulting to 1:

// all equivalent
renderer.computeAsync( kernel, [ 10 ] );
renderer.computeAsync( kernel, [ 10, 1 ] );
renderer.computeAsync( kernel, [ 10, 1, 1 ] );

I'm curious to hear other opinions

Contributor Author

@Spiri0 Jul 14, 2025

Yes, that's simple. I can do it tonight (CET). If you have any other requests by then, I can include them as well.
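
One way to support the shorter dispatch arrays suggested above is to spread the array into `dispatchWorkgroups` and let the native defaults fill in the rest. The sketch below uses a mock encoder, since `GPUComputePassEncoder` is not available outside a WebGPU context; only its default arguments (`y = 1`, `z = 1`) are modeled. `isExplicitDispatch` is a hypothetical, relaxed version of the `isValid` check in this PR.

```javascript
// Mock standing in for GPUComputePassEncoder; it only records the resolved arguments.
const recordedDispatches = [];
const mockPassEncoder = {
	dispatchWorkgroups( x, y = 1, z = 1 ) {

		recordedDispatches.push( [ x, y, z ] );

	}
};

// Hypothetical check: accept 1 to 3 positive integers as an explicit dispatch size.
function isExplicitDispatch( dispatchSize ) {

	return Array.isArray( dispatchSize )
		&& dispatchSize.length >= 1
		&& dispatchSize.length <= 3
		&& dispatchSize.every( ( n ) => Number.isInteger( n ) && n > 0 );

}

for ( const size of [ [ 10 ], [ 10, 1 ], [ 10, 1, 1 ] ] ) {

	// Spreading lets the encoder's own defaults supply the missing dimensions.
	if ( isExplicitDispatch( size ) ) mockPassEncoder.dispatchWorkgroups( ...size );

}
```

All three calls end up recording the same `[ 10, 1, 1 ]` triple, and an array such as `[ 0, 0, 0 ]` is still rejected, so it can keep acting as the "not provided" default.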


this.computeShader = this._getWGSLComputeCode( shadersData.compute, workgroupSize );

}

@@ -2103,6 +2107,8 @@ fn main( ${shaderData.varyings} ) -> ${shaderData.returnType} {
*/
_getWGSLComputeCode( shaderData, workgroupSize ) {

const [ workgroupSizeX, workgroupSizeY, workgroupSizeZ ] = workgroupSize;

return `${ this.getSignature() }
// directives
${shaderData.directives}
@@ -2122,11 +2128,13 @@ ${shaderData.uniforms}
// codes
${shaderData.codes}

@compute @workgroup_size( ${workgroupSize} )
@compute @workgroup_size( ${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ} )
fn main( ${shaderData.attributes} ) {

// system
instanceIndex = globalId.x + globalId.y * numWorkgroups.x * u32(${workgroupSize}) + globalId.z * numWorkgroups.x * numWorkgroups.y * u32(${workgroupSize});
instanceIndex = globalId.x
+ globalId.y * (${workgroupSizeX} * numWorkgroups.x)
+ globalId.z * (${workgroupSizeX} * numWorkgroups.x) * (${workgroupSizeY} * numWorkgroups.y);

// vars
${shaderData.vars}
Expand Down
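
The new `instanceIndex` expression flattens the 3D global invocation id into a linear index, using the total grid width (`workgroupSizeX * numWorkgroups.x`) and height (`workgroupSizeY * numWorkgroups.y`) as strides. The following CPU-side sketch replays that arithmetic to check that every invocation gets a unique, gap-free index. The function name and the sample sizes are illustrative assumptions, not values from the PR.

```javascript
// CPU-side replay of the WGSL expression from the diff (hypothetical helper).
function instanceIndexFor( globalId, workgroupSize, numWorkgroups ) {

	const gridWidth = workgroupSize[ 0 ] * numWorkgroups[ 0 ];
	const gridHeight = workgroupSize[ 1 ] * numWorkgroups[ 1 ];

	return globalId[ 0 ]
		+ globalId[ 1 ] * gridWidth
		+ globalId[ 2 ] * gridWidth * gridHeight;

}

// Illustrative sizes: 4x2x2 threads per workgroup, 3x2x2 workgroups.
const sampleWorkgroupSize = [ 4, 2, 2 ];
const sampleNumWorkgroups = [ 3, 2, 2 ];

const gridSize = sampleWorkgroupSize.map( ( s, i ) => s * sampleNumWorkgroups[ i ] );
const seenIndices = new Set();

// Visit every global invocation id in the dispatch grid.
for ( let z = 0; z < gridSize[ 2 ]; z ++ ) {

	for ( let y = 0; y < gridSize[ 1 ]; y ++ ) {

		for ( let x = 0; x < gridSize[ 0 ]; x ++ ) {

			seenIndices.add( instanceIndexFor( [ x, y, z ], sampleWorkgroupSize, sampleNumWorkgroups ) );

		}

	}

}

const totalInvocations = gridSize[ 0 ] * gridSize[ 1 ] * gridSize[ 2 ];
```

If the strides are right, `seenIndices` contains exactly `totalInvocations` distinct values covering `0` through `totalInvocations - 1`.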