Skip to content

[webgpu] Use different global index for flat and non-flat dispatch #5706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {backend_util, util} from '@tensorflow/tfjs-core';

import {getCoordsDataType, getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
import {getCoordsDataType, getMainHeaderString} from '../shader_preprocessor';
import {computeDispatch} from '../webgpu_util';

import {WebGPUProgram} from './webgpu_program';
Expand Down Expand Up @@ -139,9 +139,9 @@ export class ArgMinMaxProgram implements WebGPUProgram {
// add back the index along the reduced dimension to |outputCoords|.
// This function outputs the offset to the first value along
// |axis| and the stride to get the next value of the input along |axis|.
fn getInputCoordInfo(globalId : vec3<u32>, globalIndex : i32) -> vec2<i32>{
fn getInputCoordInfo(globalId : vec3<u32>) -> vec2<i32>{
let outputCoords : ${
outputCoordsType} = getOutputCoords(globalId, globalIndex);
outputCoordsType} = getOutputCoords(globalId, i32(globalId.x));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second parameter seems like meaningless for non-flat dispatch. You already passed globalId. But you still have to pass globalId.x. I prefer that you refactor getXXX method to remove the dependency of globalIndex. Similar for other places.

var i = ${this.outputShape.length - 1};

var stride = 1;
Expand All @@ -168,8 +168,7 @@ export class ArgMinMaxProgram implements WebGPUProgram {
}

${getMainHeaderString()} {
${getGlobalIndexString()}
let coordInfo = getInputCoordInfo(globalId, index);
let coordInfo = getInputCoordInfo(globalId);

var bestIndex = 0;
var bestValue = x.numbers[getInputIndex(coordInfo, bestIndex)];
Expand Down
11 changes: 5 additions & 6 deletions tfjs-backend-webgpu/src/kernels/depthwise_conv2d_3x3_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {backend_util, util} from '@tensorflow/tfjs-core';

import {getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
import {getMainHeaderString} from '../shader_preprocessor';
import {computeDispatch} from '../webgpu_util';

import {mapActivationToShaderProgram} from './activation_util';
Expand Down Expand Up @@ -72,20 +72,20 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
mapActivationToShaderProgram(this.activation, this.isVec4);
if (this.hasPreluActivation) {
activationSnippet =
`fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : i32) -> vec4<f32> {
let b = getPreluActivationWeightsAtOutCoordsByGlobalId(globalId, globalIndex);
`fn activation(a : vec4<f32>, globalId : vec3<u32>) -> vec4<f32> {
let b = getPreluActivationWeightsAtOutCoordsByGlobalId(globalId, i32(globalId.x));
${activationOp}
}`;
} else {
activationSnippet = `
fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : i32) -> vec4<f32> {
fn activation(a : vec4<f32>, globalId : vec3<u32>) -> vec4<f32> {
${activationOp}
}
`;
}

applyActivationSnippet =
`dotProd[i] = activation(dotProd[i], globalId, index);`;
`dotProd[i] = activation(dotProd[i], globalId);`;
}

const addBiasSnippet = this.addBias ?
Expand All @@ -96,7 +96,6 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
${activationSnippet}

${getMainHeaderString()} {
${getGlobalIndexString()}
let batch = 0;
let r = i32(globalId.x);
let c = i32(globalId.y) * 4;
Expand Down
9 changes: 4 additions & 5 deletions tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import {backend_util, DataType} from '@tensorflow/tfjs-core';
import {getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
import {getMainHeaderString} from '../shader_preprocessor';
import {computeDispatch} from '../webgpu_util';

import {WebGPUProgram} from './webgpu_program';
Expand Down Expand Up @@ -121,17 +121,16 @@ export class ReduceProgram implements WebGPUProgram {
}
let WorkGroupSize = ${this.workGroupSize[0]};
${reduceInSharedMemory ? sharedMemorySnippet : ''}
fn getOffset(globalId : vec3<u32>, index : i32) -> i32 {
let outputCoords = getOutputCoords(globalId, index);
fn getOffset(globalId : vec3<u32>) -> i32 {
let outputCoords = getOutputCoords(globalId, i32(globalId.x));
let offset = ${
this.outputShape.length === 1 ?
'outputCoords' :
'outputCoords[0]'} * uniforms.reduceSize;
return offset;
}
${getMainHeaderString()} {
${getGlobalIndexString()}
let offset= getOffset(globalId, index);
let offset = getOffset(globalId);
var bestValue = ${initValue};
let Length = uniforms.reduceSize;
let WorkPerThread = DIV_CEIL(Length, WorkGroupSize);
Expand Down
3 changes: 1 addition & 2 deletions tfjs-backend-webgpu/src/kernels/transpose_shared_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
import {getMainHeaderString} from '../shader_preprocessor';
import {computeDispatch} from '../webgpu_util';

import {WebGPUProgram} from './webgpu_program';
Expand Down Expand Up @@ -48,7 +48,6 @@ export class TransposeSharedProgram implements WebGPUProgram {
var<workgroup> tile : array<array<f32, ${this.workGroupSize[0] + 1}>, ${
this.workGroupSize[0]}>;
${getMainHeaderString()} {
${getGlobalIndexString()}
let workGroupID = (globalId - localId)/vec3<u32>(${
this.workGroupSize[0]}u, ${this.workGroupSize[1]}u, ${
this.workGroupSize[2]}u);
Expand Down
8 changes: 1 addition & 7 deletions tfjs-backend-webgpu/src/shader_preprocessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ export function getWorkGroupSizeString(): string {
}

export function getGlobalIndexString(): string {
return `
let index = getGlobalIndex(globalId, localId);
`;
return 'let index = getGlobalIndex(globalId, localId);';
}

export function getMainHeaderString() {
Expand Down Expand Up @@ -278,11 +276,7 @@ const SAMPLING_SNIPPETS = `
f32(shape.y) * f32(shape.z) * f32(shape.w), f32(shape.z) * f32(shape.w), f32(shape.w), 1.0)));
}

// Only used when the y/z dimension of workgroup size is 1.
fn getGlobalIndex(globalId : vec3<u32>, localId : vec3<u32>) -> i32 {
if (uniforms.dispatchSize.y == 1u && uniforms.dispatchSize.z == 1u) {
return i32(globalId.x);
}
let localInvocationIndex = localId.z * workGroupSizeX * workGroupSizeY +
localId.y * workGroupSizeX + localId.x;
let workGroupID = (globalId - localId)/vec3<u32>(
Expand Down