Skip to content

Commit a8daa8f

Browse files
author
Attila Schroeder
committed
Allows the specification of workgroups for compute shaders and the specification of the dispatchSize
1 parent 2f1c6c0 commit a8daa8f

File tree

4 files changed

+62
-47
lines changed

4 files changed

+62
-47
lines changed

src/nodes/gpgpu/ComputeNode.js

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ class ComputeNode extends Node {
2020
*
2121
* @param {Node} computeNode - TODO
2222
* @param {number} count - TODO.
23-
* @param {Array<number>} [workgroupSize=[64]] - TODO.
23+
* @param {Array<number>} [workgroupSize = [ 64, 1, 1 ]]
2424
*/
25-
constructor( computeNode, count, workgroupSize = [ 64 ] ) {
25+
constructor( computeNode, count, workgroupSize = [ 64, 1, 1 ] ) {
2626

2727
super( 'void' );
2828

@@ -53,7 +53,7 @@ class ComputeNode extends Node {
5353
* TODO
5454
*
5555
* @type {Array<number>}
56-
* @default [64]
56+
* @default [ 64, 1, 1 ]
5757
*/
5858
this.workgroupSize = workgroupSize;
5959

@@ -220,9 +220,22 @@ export default ComputeNode;
220220
* @function
221221
* @param {Node} node - TODO
222222
* @param {number} count - TODO.
223-
* @param {Array<number>} [workgroupSize=[64]] - TODO.
223+
* @param {Array<number>} [workgroupSize=[ 64, 1, 1 ]]
224224
* @returns {AtomicFunctionNode}
225225
*/
226-
export const compute = ( node, count, workgroupSize ) => nodeObject( new ComputeNode( nodeObject( node ), count, workgroupSize ) );
226+
export const compute = ( node, countOrWorkgroupSize, workgroupSize ) => {
227+
228+
let count = countOrWorkgroupSize;
229+
230+
if ( Array.isArray( countOrWorkgroupSize ) ) {
231+
232+
workgroupSize = countOrWorkgroupSize;
233+
count = null;
234+
235+
}
236+
237+
return nodeObject( new ComputeNode( nodeObject( node ), count, workgroupSize ) );
238+
239+
};
227240

228241
addMethodChaining( 'compute', compute );

src/renderers/common/Renderer.js

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,24 +1243,13 @@ class Renderer {
12431243

12441244
frameBufferTarget.depthBuffer = depth;
12451245
frameBufferTarget.stencilBuffer = stencil;
1246-
if ( outputRenderTarget !== null ) {
1247-
1248-
frameBufferTarget.setSize( outputRenderTarget.width, outputRenderTarget.height, outputRenderTarget.depth );
1249-
1250-
} else {
1251-
1252-
frameBufferTarget.setSize( width, height, 1 );
1253-
1254-
}
1255-
1246+
frameBufferTarget.setSize( width, height, outputRenderTarget !== null ? outputRenderTarget.depth : 1 );
12561247
frameBufferTarget.viewport.copy( this._viewport );
12571248
frameBufferTarget.scissor.copy( this._scissor );
12581249
frameBufferTarget.viewport.multiplyScalar( this._pixelRatio );
12591250
frameBufferTarget.scissor.multiplyScalar( this._pixelRatio );
12601251
frameBufferTarget.scissorTest = this._scissorTest;
12611252
frameBufferTarget.multiview = outputRenderTarget !== null ? outputRenderTarget.multiview : false;
1262-
frameBufferTarget.resolveDepthBuffer = outputRenderTarget !== null ? outputRenderTarget.resolveDepthBuffer : true;
1263-
frameBufferTarget._autoAllocateDepthBuffer = outputRenderTarget !== null ? outputRenderTarget._autoAllocateDepthBuffer : false;
12641253

12651254
return frameBufferTarget;
12661255

@@ -1516,15 +1505,6 @@ class Renderer {
15161505

15171506
}
15181507

1519-
_setXRLayerSize( width, height ) {
1520-
1521-
this._width = width;
1522-
this._height = height;
1523-
1524-
this.setViewport( 0, 0, width, height );
1525-
1526-
}
1527-
15281508
/**
15291509
* The output pass performs tone mapping and color space conversion.
15301510
*
@@ -2310,7 +2290,7 @@ class Renderer {
23102290
* @param {Node|Array<Node>} computeNodes - The compute node(s).
23112291
* @return {Promise|undefined} A Promise that resolve when the compute has finished. Only returned when the renderer has not been initialized.
23122292
*/
2313-
compute( computeNodes ) {
2293+
compute( computeNodes, dispatchSize = [ 0, 0, 0 ] ) {
23142294

23152295
if ( this._isDeviceLost === true ) return;
23162296

@@ -2389,7 +2369,7 @@ class Renderer {
23892369
const computeBindings = bindings.getForCompute( computeNode );
23902370
const computePipeline = pipelines.getForCompute( computeNode, computeBindings );
23912371

2392-
backend.compute( computeNodes, computeNode, computeBindings, computePipeline );
2372+
backend.compute( computeNodes, computeNode, computeBindings, computePipeline, dispatchSize );
23932373

23942374
}
23952375

src/renderers/webgpu/WebGPUBackend.js

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,9 +1319,13 @@ class WebGPUBackend extends Backend {
13191319
* @param {Array<BindGroup>} bindings - The bindings.
13201320
* @param {ComputePipeline} pipeline - The compute pipeline.
13211321
*/
1322-
compute( computeGroup, computeNode, bindings, pipeline ) {
1322+
compute( computeGroup, computeNode, bindings, pipeline, dispatchSize ) {
13231323

1324+
const computeNodeData = this.get( computeNode );
13241325
const { passEncoderGPU } = this.get( computeGroup );
1326+
const isValid = dispatchSize[ 0 ] > 0 && dispatchSize[ 1 ] > 0 && dispatchSize[ 2 ] > 0;
1327+
1328+
dispatchSize = isValid ? dispatchSize : computeNodeData;
13251329

13261330
// pipeline
13271331

@@ -1340,30 +1344,38 @@ class WebGPUBackend extends Backend {
13401344

13411345
}
13421346

1343-
const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension;
1347+
if ( isValid ) {
13441348

1345-
const computeNodeData = this.get( computeNode );
1349+
passEncoderGPU.dispatchWorkgroups(
1350+
dispatchSize[ 0 ],
1351+
dispatchSize[ 1 ],
1352+
dispatchSize[ 2 ]
1353+
);
13461354

1347-
if ( computeNodeData.dispatchSize === undefined ) computeNodeData.dispatchSize = { x: 0, y: 1, z: 1 };
1355+
} else {
13481356

1349-
const { dispatchSize } = computeNodeData;
1357+
const maxComputeWorkgroupsPerDimension = this.device.limits.maxComputeWorkgroupsPerDimension;
13501358

1351-
if ( computeNode.dispatchCount > maxComputeWorkgroupsPerDimension ) {
1359+
if ( computeNodeData.dispatchSize === undefined ) computeNodeData.dispatchSize = { x: 0, y: 1, z: 1 };
13521360

1353-
dispatchSize.x = Math.min( computeNode.dispatchCount, maxComputeWorkgroupsPerDimension );
1354-
dispatchSize.y = Math.ceil( computeNode.dispatchCount / maxComputeWorkgroupsPerDimension );
1361+
if ( computeNode.dispatchCount > maxComputeWorkgroupsPerDimension ) {
13551362

1356-
} else {
1363+
dispatchSize.x = Math.min( computeNode.dispatchCount, maxComputeWorkgroupsPerDimension );
1364+
dispatchSize.y = Math.ceil( computeNode.dispatchCount / maxComputeWorkgroupsPerDimension );
1365+
1366+
} else {
13571367

1358-
dispatchSize.x = computeNode.dispatchCount;
1368+
dispatchSize.x = computeNode.dispatchCount;
13591369

1360-
}
1370+
}
13611371

1362-
passEncoderGPU.dispatchWorkgroups(
1363-
dispatchSize.x,
1364-
dispatchSize.y,
1365-
dispatchSize.z
1366-
);
1372+
passEncoderGPU.dispatchWorkgroups(
1373+
dispatchSize.x,
1374+
dispatchSize.y,
1375+
dispatchSize.z
1376+
);
1377+
1378+
}
13671379

13681380
}
13691381

src/renderers/webgpu/nodes/WGSLNodeBuilder.js

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,7 +1898,13 @@ ${ flowData.code }
18981898

18991899
} else {
19001900

1901-
this.computeShader = this._getWGSLComputeCode( shadersData.compute, ( this.object.workgroupSize || [ 64 ] ).join( ', ' ) );
1901+
//this.computeShader = this._getWGSLComputeCode( shadersData.compute, ( this.object.workgroupSize || [ 64 ] ).join( ', ' ) );
1902+
1903+
const workgroupSize = this.object.workgroupSize || [ 8, 8, 1 ];
1904+
1905+
if ( workgroupSize.length !== 3 ) throw new Error( "workgroupSize must have 3 elements" );
1906+
1907+
this.computeShader = this._getWGSLComputeCode( shadersData.compute, workgroupSize );
19021908

19031909
}
19041910

@@ -2103,6 +2109,8 @@ fn main( ${shaderData.varyings} ) -> ${shaderData.returnType} {
21032109
*/
21042110
_getWGSLComputeCode( shaderData, workgroupSize ) {
21052111

2112+
const [ workgroupSizeX, workgroupSizeY, workgroupSizeZ ] = workgroupSize;
2113+
21062114
return `${ this.getSignature() }
21072115
// directives
21082116
${shaderData.directives}
@@ -2122,11 +2130,13 @@ ${shaderData.uniforms}
21222130
// codes
21232131
${shaderData.codes}
21242132
2125-
@compute @workgroup_size( ${workgroupSize} )
2133+
@compute @workgroup_size( ${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ} )
21262134
fn main( ${shaderData.attributes} ) {
21272135
21282136
// system
2129-
instanceIndex = globalId.x + globalId.y * numWorkgroups.x * u32(${workgroupSize}) + globalId.z * numWorkgroups.x * numWorkgroups.y * u32(${workgroupSize});
2137+
instanceIndex = globalId.x
2138+
+ globalId.y * (${workgroupSizeX} * numWorkgroups.x)
2139+
+ globalId.z * (${workgroupSizeX} * numWorkgroups.x) * (${workgroupSizeY} * numWorkgroups.y);
21302140
21312141
// vars
21322142
${shaderData.vars}

0 commit comments

Comments
 (0)