Skip to content

Commit dc08875

Browse files
committed
Do not use getGlobalIndexString when the dispatch is non-flat
1 parent fce8908 commit dc08875

File tree

6 files changed

+17
-26
lines changed

6 files changed

+17
-26
lines changed

tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import {backend_util, util} from '@tensorflow/tfjs-core';
1919

20-
import {getCoordsDataType, getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
20+
import {getCoordsDataType, getMainHeaderString} from '../shader_preprocessor';
2121
import {computeDispatch} from '../webgpu_util';
2222

2323
import {WebGPUProgram} from './webgpu_program';
@@ -139,9 +139,9 @@ export class ArgMinMaxProgram implements WebGPUProgram {
139139
// add back the index along the reduced dimension to |outputCoords|.
140140
// This function outputs the offset to the first value along
141141
// |axis| and the stride to get the next value of the input along |axis|.
142-
fn getInputCoordInfo(globalId : vec3<u32>, globalIndex : i32) -> vec2<i32>{
142+
fn getInputCoordInfo(globalId : vec3<u32>) -> vec2<i32>{
143143
let outputCoords : ${
144-
outputCoordsType} = getOutputCoords(globalId, globalIndex);
144+
outputCoordsType} = getOutputCoords(globalId, i32(globalId.x));
145145
var i = ${this.outputShape.length - 1};
146146
147147
var stride = 1;
@@ -168,8 +168,7 @@ export class ArgMinMaxProgram implements WebGPUProgram {
168168
}
169169
170170
${getMainHeaderString()} {
171-
${getGlobalIndexString(true)}
172-
let coordInfo = getInputCoordInfo(globalId, index);
171+
let coordInfo = getInputCoordInfo(globalId);
173172
174173
var bestIndex = 0;
175174
var bestValue = x.numbers[getInputIndex(coordInfo, bestIndex)];

tfjs-backend-webgpu/src/kernels/depthwise_conv2d_3x3_webgpu.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import {backend_util, util} from '@tensorflow/tfjs-core';
1919

20-
import {getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
20+
import {getMainHeaderString} from '../shader_preprocessor';
2121
import {computeDispatch} from '../webgpu_util';
2222

2323
import {mapActivationToShaderProgram} from './activation_util';
@@ -72,20 +72,20 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
7272
mapActivationToShaderProgram(this.activation, this.isVec4);
7373
if (this.hasPreluActivation) {
7474
activationSnippet =
75-
`fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : i32) -> vec4<f32> {
76-
let b = getPreluActivationWeightsAtOutCoordsByGlobalId(globalId, globalIndex);
75+
`fn activation(a : vec4<f32>, globalId : vec3<u32>) -> vec4<f32> {
76+
let b = getPreluActivationWeightsAtOutCoordsByGlobalId(globalId, i32(globalId.x));
7777
${activationOp}
7878
}`;
7979
} else {
8080
activationSnippet = `
81-
fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : i32) -> vec4<f32> {
81+
fn activation(a : vec4<f32>, globalId : vec3<u32>) -> vec4<f32> {
8282
${activationOp}
8383
}
8484
`;
8585
}
8686

8787
applyActivationSnippet =
88-
`dotProd[i] = activation(dotProd[i], globalId, index);`;
88+
`dotProd[i] = activation(dotProd[i], globalId);`;
8989
}
9090

9191
const addBiasSnippet = this.addBias ?
@@ -96,7 +96,6 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
9696
${activationSnippet}
9797
9898
${getMainHeaderString()} {
99-
${getGlobalIndexString(true)}
10099
let batch = 0;
101100
let r = i32(globalId.x);
102101
let c = i32(globalId.y) * 4;

tfjs-backend-webgpu/src/kernels/reduce_webgpu.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import {backend_util, DataType} from '@tensorflow/tfjs-core';
19-
import {getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
19+
import {getMainHeaderString} from '../shader_preprocessor';
2020
import {computeDispatch} from '../webgpu_util';
2121

2222
import {WebGPUProgram} from './webgpu_program';
@@ -121,17 +121,16 @@ export class ReduceProgram implements WebGPUProgram {
121121
}
122122
let WorkGroupSize = ${this.workGroupSize[0]};
123123
${reduceInSharedMemory ? sharedMemorySnippet : ''}
124-
fn getOffset(globalId : vec3<u32>, index : i32) -> i32 {
125-
let outputCoords = getOutputCoords(globalId, index);
124+
fn getOffset(globalId : vec3<u32>) -> i32 {
125+
let outputCoords = getOutputCoords(globalId, i32(globalId.x));
126126
let offset = ${
127127
this.outputShape.length === 1 ?
128128
'outputCoords' :
129129
'outputCoords[0]'} * uniforms.reduceSize;
130130
return offset;
131131
}
132132
${getMainHeaderString()} {
133-
${getGlobalIndexString(true)}
134-
let offset= getOffset(globalId, index);
133+
let offset = getOffset(globalId);
135134
var bestValue = ${initValue};
136135
let Length = uniforms.reduceSize;
137136
let WorkPerThread = DIV_CEIL(Length, WorkGroupSize);

tfjs-backend-webgpu/src/kernels/slice_webgpu.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ export class SliceProgram implements WebGPUProgram {
6565

6666
const userCode = `
6767
${getMainHeaderString()} {
68-
${getGlobalIndexString(true)}
68+
${getGlobalIndexString()}
6969
if (index < uniforms.size) {
7070
var sourceLoc : ${dtype};
7171
let coords = getOutputCoords(globalId, index);

tfjs-backend-webgpu/src/kernels/transpose_shared_webgpu.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18-
import {getGlobalIndexString, getMainHeaderString} from '../shader_preprocessor';
18+
import {getMainHeaderString} from '../shader_preprocessor';
1919
import {computeDispatch} from '../webgpu_util';
2020

2121
import {WebGPUProgram} from './webgpu_program';
@@ -48,7 +48,6 @@ export class TransposeSharedProgram implements WebGPUProgram {
4848
var<workgroup> tile : array<array<f32, ${this.workGroupSize[0] + 1}>, ${
4949
this.workGroupSize[0]}>;
5050
${getMainHeaderString()} {
51-
${getGlobalIndexString(true)}
5251
let workGroupID = (globalId - localId)/vec3<u32>(${
5352
this.workGroupSize[0]}u, ${this.workGroupSize[1]}u, ${
5453
this.workGroupSize[2]}u);

tfjs-backend-webgpu/src/shader_preprocessor.ts

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,8 @@ export function getWorkGroupSizeString(): string {
7171
`;
7272
}
7373

74-
export function getGlobalIndexString(nonFlatDispatch = false): string {
75-
if (nonFlatDispatch) {
76-
return 'let index = getGlobalIndex(globalId, localId);';
77-
} else {
78-
// Only used when the y/z dimension of workgroup size is 1.
79-
return 'let index = i32(globalId.x);';
80-
}
74+
export function getGlobalIndexString(): string {
75+
return 'let index = getGlobalIndex(globalId, localId);';
8176
}
8277

8378
export function getMainHeaderString() {

0 commit comments

Comments
 (0)