Commit b9d2e24

webgpu: Reduce binary ops shader variants
PERF This PR reduces warmup time by reducing the number of binary-op shader variants. It may slightly hurt inference time, but it greatly improves warmup: bodypix-mobilenet's first pass is ~300 ms faster on CFL.
1 parent 69858e1 commit b9d2e24
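
The win comes from shader-key cardinality: before this change, BinaryOpSharedProgram baked the 1-D operand's length into its shader key (and into the workgroup buffer size), so every distinct length compiled a separate WGSL pipeline during warmup; afterwards the key varies only with the op and two booleans. A minimal sketch of the difference, using the key formats from the diff below (the helper functions here are illustrative, not part of tfjs):

// Illustrative only: key formats taken from binary_op_shared_webgpu.ts below.
type BinaryOpType = string;  // stand-in for the real enum

// Old key: one compiled pipeline per (op, operand length, side).
const oldKey = (op: BinaryOpType, lastDimensionSize: number,
                useSharedMemoryWithB: boolean) =>
    `binaryShared_${op}_${lastDimensionSize}_${useSharedMemoryWithB}`;

// New key: at most four pipelines per op, regardless of operand length.
const newKey = (op: BinaryOpType, useSharedMemoryWithB: boolean,
                isScater: boolean) =>
    `binaryShared_${op}_${useSharedMemoryWithB}_${isScater}`;

// Operand lengths 256 and 257 used to miss the pipeline cache...
console.log(oldKey('PRELU', 256, true) !== oldKey('PRELU', 257, true));  // true
// ...but now map onto the same cached variant.
console.log(newKey('PRELU', true, false) === newKey('PRELU', true, false));  // true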

File tree

5 files changed: 35 additions, 45 deletions

tfjs-backend-webgpu/src/kernels/Prelu.ts

Lines changed: 2 additions & 2 deletions
@@ -20,14 +20,14 @@ import {KernelConfig, KernelFunc, Prelu, PreluInputs, TensorInfo} from '@tensorf
 import {WebGPUBackend} from '../backend_webgpu';
 
 import {BinaryOpType} from './binary_op_util';
-import {BinaryOpProgram} from './binary_op_webgpu';
+import {getBinaryProgram} from './binary_ops';
 
 export function prelu(args: {inputs: PreluInputs, backend: WebGPUBackend}):
     TensorInfo {
   const {inputs, backend} = args;
   const {x, alpha} = inputs;
 
-  const program = new BinaryOpProgram(BinaryOpType.PRELU, x.shape, alpha.shape);
+  const program = getBinaryProgram(BinaryOpType.PRELU, x.shape, alpha.shape);
   return backend.runWebGPUProgram(program, [x, alpha], 'float32');
 }

tfjs-backend-webgpu/src/kernels/binary_op_shared_webgpu.ts

Lines changed: 18 additions & 29 deletions
@@ -15,8 +15,6 @@
  * =============================================================================
  */
 
-import {backend_util} from '@tensorflow/tfjs-core';
-
 import {getMainHeaderAndGlobalIndexString} from '../shader_preprocessor';
 import {computeDispatch, flatDispatchLayout} from '../webgpu_util';
 import {BinaryOpType, getBinaryOpString} from './binary_op_util';
@@ -29,66 +27,57 @@ export class BinaryOpSharedProgram implements WebGPUProgram {
   dispatchLayout: {x: number[]};
   dispatch: [number, number, number];
   variableNames = ['A', 'B'];
-  workPerThread: number;
+  workPerThread = 4;
   workGroupSize: [number, number, number];
   useSharedMemoryWithB: boolean;
-  lastDimensionSize: number;
+  isScater: boolean;
   op: BinaryOpType;
   size = true;
 
   constructor(
-      op: BinaryOpType, aShape: number[], bShape: number[],
-      useSharedMemoryWithB: boolean) {
+      op: BinaryOpType, outputShape: number[], useSharedMemoryWithB: boolean,
+      isScater: boolean) {
     // This is an experimental value when using shared memory.
     // Note that the maximum of workgroup X dimension is 256.
     const workGroupSizeX = 256;
     this.workGroupSize = [workGroupSizeX, 1, 1];
-    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
+    this.outputShape = outputShape;
     this.dispatchLayout = flatDispatchLayout(this.outputShape);
-    this.lastDimensionSize = useSharedMemoryWithB ? bShape[0] : aShape[0];
-    if (this.lastDimensionSize < 256) {
-      this.workPerThread = 1;
-    } else if (this.lastDimensionSize < 512) {
-      this.workPerThread = 2;
-    } else {
-      this.workPerThread = 4;
-    }
+    this.isScater = isScater;
     this.dispatch = computeDispatch(
         this.dispatchLayout, this.outputShape, this.workGroupSize,
         [this.workPerThread, 1, 1]);
 
     this.useSharedMemoryWithB = useSharedMemoryWithB;
     this.op = op;
-    // this.lastDimensionSize is used as sharedBuf array size, so can not be
-    // used as uniform.
-    this.shaderKey = `binaryShared_${op}_${this.lastDimensionSize}_${
-        this.useSharedMemoryWithB}`;
+    this.shaderKey =
+        `binaryShared_${op}_${this.useSharedMemoryWithB}_${isScater}`;
   }
 
   getUserCode(): string {
-    const sharedIndexSnippet = this.lastDimensionSize > 1 ?
-        `coords[${this.outputShape.length - 1}]` :
-        '0';
+    const sharedIndexSnippet =
+        this.isScater ? '0' : `coords[${this.outputShape.length - 1}]`;
     const accessDataSnippet = this.useSharedMemoryWithB ?
-        `let a = getAAtOutCoordsByCoords(coords);
+        `let a = getAAtOutCoordsByGlobalIndex(flatIndex);
         let b = sharedBuf[${sharedIndexSnippet}];` :
         `let a = sharedBuf[${sharedIndexSnippet}];
-        let b = getBAtOutCoordsByCoords(coords);`;
+        let b = getBAtOutCoordsByGlobalIndex(flatIndex);`;
 
-    const opStr = getBinaryOpString(this.op, false);
     const userCode = `
       fn binaryOperation(a : f32, b : f32) -> f32 {
-        ${opStr}
+        ${getBinaryOpString(this.op, false)}
      }
-      var<workgroup> sharedBuf : array<f32, ${this.lastDimensionSize}>;
+
+      var<workgroup> sharedBuf : array<f32, ${
+        this.workGroupSize[0] * this.workPerThread}>;
      ${getMainHeaderAndGlobalIndexString()}

        // Fill in the shared memory buffer. Here we need a loop to make sure
        // that all data in A|B are uploaded when |sharedMemorySize| is larger
        // than work group size.
        for(var localIndex = i32(localId.x); localIndex < ${
-          this.lastDimensionSize}; localIndex = localIndex + ${
-          this.workGroupSize[0]}) {
+          this.useSharedMemoryWithB ? 'uniforms.bShape' : 'uniforms.aShape'};
+          localIndex = localIndex + ${this.workGroupSize[0]}) {
          sharedBuf[localIndex] = f32(${
            this.useSharedMemoryWithB ? 'B' : 'A'}.numbers[localIndex]);
        }
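
A note on the new sizing, with the arithmetic spelled out from the values above: sharedBuf is now a compile-time constant of workGroupSize[0] * workPerThread = 256 * 4 = 1024 f32 elements instead of lastDimensionSize, and the fill loop's upper bound moved into a uniform, so the shader text no longer depends on the operand's length. The dispatcher's new < 512 eligibility threshold (see binary_ops.ts below) appears to be what keeps the shared operand inside that fixed buffer. A minimal sketch of the bound, mirroring (not replacing) the constructor above:

// Values copied from the constructor above; illustrative only.
const workGroupSizeX = 256;                            // workGroupSize[0]
const workPerThread = 4;                               // now a constant
const sharedBufSize = workGroupSizeX * workPerThread;  // 1024 f32 slots

// binary_ops.ts only routes 1-D operands with length < 512 to this program,
// so any eligible operand fits; the element count actually read at runtime
// comes from uniforms.aShape/uniforms.bShape, not from the shader source.
console.assert(511 < sharedBufSize);  // holds for every eligible shape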

tfjs-backend-webgpu/src/kernels/binary_op_vec4_webgpu.ts

Lines changed: 2 additions & 4 deletions
@@ -15,7 +15,6 @@
  * =============================================================================
  */
 
-import {backend_util} from '@tensorflow/tfjs-core';
 import {getMainHeaderAndGlobalIndexString} from '../shader_preprocessor';
 import {computeDispatch, flatDispatchLayout} from '../webgpu_util';
 import {BinaryOpType, getBinaryOpString} from './binary_op_util';
@@ -33,13 +32,12 @@ export class BinaryOpVec4Program implements WebGPUProgram {
   isVec4 = true;
   op: BinaryOpType;
   size = true;
-  fitShape: boolean;
 
-  constructor(op: BinaryOpType, aShape: number[], bShape: number[]) {
+  constructor(op: BinaryOpType, outputShape: number[]) {
     // TODO(jiajia.qin@intel.com): Heuristically select a good work group size.
     const workGroupSizeX = 128;
     this.workGroupSize = [workGroupSizeX, 1, 1];
-    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
+    this.outputShape = outputShape;
     this.dispatchLayout = flatDispatchLayout(this.outputShape);
     this.dispatch = computeDispatch(
         this.dispatchLayout, this.outputShape, this.workGroupSize,

tfjs-backend-webgpu/src/kernels/binary_op_webgpu.ts

Lines changed: 2 additions & 3 deletions
@@ -15,7 +15,6 @@
  * =============================================================================
  */
 
-import {backend_util} from '@tensorflow/tfjs-core';
 import {getMainHeaderAndGlobalIndexString} from '../shader_preprocessor';
 import {computeDispatch, flatDispatchLayout} from '../webgpu_util';
 import {BinaryOpType, getBinaryOpString} from './binary_op_util';
@@ -32,11 +31,11 @@ export class BinaryOpProgram implements WebGPUProgram {
   op: BinaryOpType;
   size = true;
 
-  constructor(op: BinaryOpType, aShape: number[], bShape: number[]) {
+  constructor(op: BinaryOpType, outputShape: number[]) {
     // TODO(jiajia.qin@intel.com): Heuristically select a good work group size.
     const workGroupSizeX = 128;
     this.workGroupSize = [workGroupSizeX, 1, 1];
-    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
+    this.outputShape = outputShape;
     this.dispatchLayout = flatDispatchLayout(this.outputShape);
 
     this.dispatch = computeDispatch(

tfjs-backend-webgpu/src/kernels/binary_ops.ts

Lines changed: 11 additions & 7 deletions
@@ -15,26 +15,30 @@
  * =============================================================================
  */
 
-import {util} from '@tensorflow/tfjs-core';
+import {backend_util, util} from '@tensorflow/tfjs-core';
+
 import {BinaryOpSharedProgram} from './binary_op_shared_webgpu';
+import {BinaryOpType} from './binary_op_util';
 import {BinaryOpVec4Program} from './binary_op_vec4_webgpu';
 import {BinaryOpProgram} from './binary_op_webgpu';
-import {BinaryOpType} from './binary_op_util';
 
 export function getBinaryProgram(
     op: BinaryOpType, aShape: number[], bShape: number[]) {
+  const outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
   const useVec4 =
       util.arraysEqual(aShape, bShape) && util.sizeFromShape(aShape) % 4 === 0;
   if (useVec4) {
-    return new BinaryOpVec4Program(op, aShape, bShape);
+    return new BinaryOpVec4Program(op, outputShape);
   }
   const useSharedMemoryWithA =
-      aShape.length === 1 && bShape.length > 1 && aShape[0] < 1024;
+      aShape.length === 1 && bShape.length > 1 && aShape[0] < 512;
   const useSharedMemoryWithB =
-      bShape.length === 1 && aShape.length > 1 && bShape[0] < 1024;
+      bShape.length === 1 && aShape.length > 1 && bShape[0] < 512;
   if (useSharedMemoryWithA || useSharedMemoryWithB) {
-    return new BinaryOpSharedProgram(op, aShape, bShape, useSharedMemoryWithB);
+    const isScater = useSharedMemoryWithB ? bShape[0] === 1 : aShape[0] === 1;
+    return new BinaryOpSharedProgram(
+        op, outputShape, useSharedMemoryWithB, isScater);
   } else {
-    return new BinaryOpProgram(op, aShape, bShape);
+    return new BinaryOpProgram(op, outputShape);
   }
 }
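
To summarize the routing, a hedged sketch of which program each shape pair now selects (shapes are illustrative; BinaryOpType.PRELU is the member exercised by Prelu.ts above):

import {BinaryOpType} from './binary_op_util';
import {getBinaryProgram} from './binary_ops';

// Equal shapes whose size is a multiple of 4 -> BinaryOpVec4Program.
getBinaryProgram(BinaryOpType.PRELU, [2, 4], [2, 4]);

// 1-D operand shorter than 512 against a higher-rank tensor
// -> BinaryOpSharedProgram with isScater = false (bShape[0] > 1).
getBinaryProgram(BinaryOpType.PRELU, [8, 256], [256]);

// Length-1 operand -> the same shared program with isScater = true,
// so every scalar-broadcast case reuses one compiled pipeline.
getBinaryProgram(BinaryOpType.PRELU, [8, 256], [1]);

// Anything else (here: equal shapes, but size 6 is not divisible by 4)
// falls through to the plain BinaryOpProgram.
getBinaryProgram(BinaryOpType.PRELU, [2, 3], [2, 3]);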

0 commit comments
