diff --git a/src/nodes/gpgpu/ComputeNode.js b/src/nodes/gpgpu/ComputeNode.js index ec64913b2e43e0..4bb89f8f37e76a 100644 --- a/src/nodes/gpgpu/ComputeNode.js +++ b/src/nodes/gpgpu/ComputeNode.js @@ -20,9 +20,9 @@ class ComputeNode extends Node { * * @param {Node} computeNode - TODO * @param {number} count - TODO. - * @param {Array} [workgroupSize=[64]] - TODO. + * @param {Array} [workgroupSize = [ 64, 1, 1 ]] */ - constructor( computeNode, count, workgroupSize = [ 64 ] ) { + constructor( computeNode, count, workgroupSize = [ 64, 1, 1 ] ) { super( 'void' ); @@ -53,7 +53,7 @@ class ComputeNode extends Node { * TODO * * @type {Array} - * @default [64] + * @default [ 64, 1, 1 ] */ this.workgroupSize = workgroupSize; @@ -219,10 +219,47 @@ export default ComputeNode; * @tsl * @function * @param {Node} node - TODO - * @param {number} count - TODO. - * @param {Array} [workgroupSize=[64]] - TODO. + * @param {number} countOrWorkgroupSize - TODO, depends on the future of count * @returns {AtomicFunctionNode} */ -export const compute = ( node, count, workgroupSize ) => nodeObject( new ComputeNode( nodeObject( node ), count, workgroupSize ) ); +export const compute = ( node, countOrWorkgroupSize ) => { + + let count = null; + let workgroupSize = [ 64, 1, 1 ]; //default + + if ( Array.isArray( countOrWorkgroupSize ) ) { + + workgroupSize = countOrWorkgroupSize; + + if ( workgroupSize.length === 0 || workgroupSize.length > 3 ) { + + throw new Error( 'workgroupSize must have 1, 2, or 3 elements' ); + + } + + for ( let i = 0; i < workgroupSize.length; i ++ ) { + + const val = workgroupSize[ i ]; + + if ( typeof val !== 'number' || val <= 0 || ! Number.isInteger( val ) ) { + + throw new Error( `workgroupSize element at index ${i} must be a positive integer` ); + + } + + } + + // Implicit fill-up to [ x, y, z ] with 1s, just like WGSL treats @workgroup_size when fewer dimensions are specified + while ( workgroupSize.length < 3 ) workgroupSize.push( 1 ); + + } else { + + count = countOrWorkgroupSize; + + } + + return nodeObject( new ComputeNode( nodeObject( node ), count, workgroupSize ) ); + +}; addMethodChaining( 'compute', compute ); diff --git a/src/renderers/common/Renderer.js b/src/renderers/common/Renderer.js index c06fe9b02d4c94..98ac6a2bce2fea 100644 --- a/src/renderers/common/Renderer.js +++ b/src/renderers/common/Renderer.js @@ -2308,9 +2308,10 @@ class Renderer { * if the renderer has been initialized. * * @param {Node|Array} computeNodes - The compute node(s). + * @param {Array} dispatchSize - Array with [ x,y,z ] values for dispatch. Default = null * @return {Promise|undefined} A Promise that resolve when the compute has finished. Only returned when the renderer has not been initialized. */ - compute( computeNodes ) { + compute( computeNodes, dispatchSize = null ) { if ( this._isDeviceLost === true ) return; @@ -2389,7 +2390,7 @@ class Renderer { const computeBindings = bindings.getForCompute( computeNode ); const computePipeline = pipelines.getForCompute( computeNode, computeBindings ); - backend.compute( computeNodes, computeNode, computeBindings, computePipeline ); + backend.compute( computeNodes, computeNode, computeBindings, computePipeline, dispatchSize ); } diff --git a/src/renderers/webgpu/WebGPUBackend.js b/src/renderers/webgpu/WebGPUBackend.js index 84ce1d8a07c65f..7d15d935500af6 100644 --- a/src/renderers/webgpu/WebGPUBackend.js +++ b/src/renderers/webgpu/WebGPUBackend.js @@ -1318,9 +1318,11 @@ class WebGPUBackend extends Backend { * @param {Node} computeNode - The compute node. * @param {Array} bindings - The bindings. * @param {ComputePipeline} pipeline - The compute pipeline. + * @param {Array} dispatchSize - Array with [x,y,z] values for dispatch. */ - compute( computeGroup, computeNode, bindings, pipeline ) { + compute( computeGroup, computeNode, bindings, pipeline, dispatchSize ) { + const computeNodeData = this.get( computeNode ); const { passEncoderGPU } = this.get( computeGroup ); // pipeline @@ -1340,30 +1342,66 @@ class WebGPUBackend extends Backend { } - const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension; + if ( dispatchSize !== null ) { - const computeNodeData = this.get( computeNode ); + if ( ! Array.isArray( dispatchSize ) ) { + + throw new Error( 'dispatchSize must be an array' ); + + } + + if ( dispatchSize.length === 0 || dispatchSize.length > 3 ) { - if ( computeNodeData.dispatchSize === undefined ) computeNodeData.dispatchSize = { x: 0, y: 1, z: 1 }; + throw new Error( 'dispatchSize must have 1, 2, or 3 elements' ); + + } - const { dispatchSize } = computeNodeData; + for ( let i = 0; i < dispatchSize.length; i ++ ) { - if ( computeNode.dispatchCount > maxComputeWorkgroupsPerDimension ) { + const value = dispatchSize[ i ]; - dispatchSize.x = Math.min( computeNode.dispatchCount, maxComputeWorkgroupsPerDimension ); - dispatchSize.y = Math.ceil( computeNode.dispatchCount / maxComputeWorkgroupsPerDimension ); + if ( typeof value !== 'number' || value <= 0 || ! Number.isInteger( value ) ) { + + throw new Error( `dispatchSize element at index ${i} must be a positive integer` ); + + } + + } + + while ( dispatchSize.length < 3 ) dispatchSize.push( 1 ); + + passEncoderGPU.dispatchWorkgroups( + dispatchSize[ 0 ], + dispatchSize[ 1 ], + dispatchSize[ 2 ] + ); } else { - dispatchSize.x = computeNode.dispatchCount; + dispatchSize = computeNodeData; - } + const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension; - passEncoderGPU.dispatchWorkgroups( - dispatchSize.x, - dispatchSize.y, - dispatchSize.z - ); + if ( computeNodeData.dispatchSize === undefined ) computeNodeData.dispatchSize = { x: 0, y: 1, z: 1 }; + + if ( computeNode.dispatchCount > maxComputeWorkgroupsPerDimension ) { + + dispatchSize.x = Math.min( computeNode.dispatchCount, maxComputeWorkgroupsPerDimension ); + dispatchSize.y = Math.ceil( computeNode.dispatchCount / maxComputeWorkgroupsPerDimension ); + + } else { + + dispatchSize.x = computeNode.dispatchCount; + + } + + passEncoderGPU.dispatchWorkgroups( + dispatchSize.x, + dispatchSize.y, + dispatchSize.z + ); + + } } diff --git a/src/renderers/webgpu/nodes/WGSLNodeBuilder.js b/src/renderers/webgpu/nodes/WGSLNodeBuilder.js index f67f828f487d95..ec4bceeb321b1a 100644 --- a/src/renderers/webgpu/nodes/WGSLNodeBuilder.js +++ b/src/renderers/webgpu/nodes/WGSLNodeBuilder.js @@ -1898,7 +1898,9 @@ ${ flowData.code } } else { - this.computeShader = this._getWGSLComputeCode( shadersData.compute, ( this.object.workgroupSize || [ 64 ] ).join( ', ' ) ); + const workgroupSize = this.object.workgroupSize; //early strictly validated in computeNode + + this.computeShader = this._getWGSLComputeCode( shadersData.compute, workgroupSize ); } @@ -2103,6 +2105,8 @@ fn main( ${shaderData.varyings} ) -> ${shaderData.returnType} { */ _getWGSLComputeCode( shaderData, workgroupSize ) { + const [ workgroupSizeX, workgroupSizeY, workgroupSizeZ ] = workgroupSize; + return `${ this.getSignature() } // directives ${shaderData.directives} @@ -2122,11 +2126,13 @@ ${shaderData.uniforms} // codes ${shaderData.codes} -@compute @workgroup_size( ${workgroupSize} ) +@compute @workgroup_size( ${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ} ) fn main( ${shaderData.attributes} ) { // system - instanceIndex = globalId.x + globalId.y * numWorkgroups.x * u32(${workgroupSize}) + globalId.z * numWorkgroups.x * numWorkgroups.y * u32(${workgroupSize}); + instanceIndex = globalId.x + + globalId.y * (${workgroupSizeX} * numWorkgroups.x) + + globalId.z * (${workgroupSizeX} * numWorkgroups.x) * (${workgroupSizeY} * numWorkgroups.y); // vars ${shaderData.vars}