 * =============================================================================
 */

-import {backend_util} from '@tensorflow/tfjs-core';
-
import {getMainHeaderAndGlobalIndexString} from '../shader_preprocessor';
import {computeDispatch, flatDispatchLayout} from '../webgpu_util';
import {BinaryOpType, getBinaryOpString} from './binary_op_util';
@@ -29,66 +27,57 @@ export class BinaryOpSharedProgram implements WebGPUProgram {
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames = ['A', 'B'];
-  workPerThread: number;
+  workPerThread = 4;
  workGroupSize: [number, number, number];
  useSharedMemoryWithB: boolean;
-  lastDimensionSize: number;
+  isScater: boolean;
  op: BinaryOpType;
  size = true;

  constructor(
-      op: BinaryOpType, aShape: number[], bShape: number[],
-      useSharedMemoryWithB: boolean) {
+      op: BinaryOpType, outputShape: number[], useSharedMemoryWithB: boolean,
+      isScater: boolean) {
    // This is an experimental value when using shared memory.
    // Note that the maximum of workgroup X dimension is 256.
    const workGroupSizeX = 256;
    this.workGroupSize = [workGroupSizeX, 1, 1];
-    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
+    this.outputShape = outputShape;
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
-    this.lastDimensionSize = useSharedMemoryWithB ? bShape[0] : aShape[0];
-    if (this.lastDimensionSize < 256) {
-      this.workPerThread = 1;
-    } else if (this.lastDimensionSize < 512) {
-      this.workPerThread = 2;
-    } else {
-      this.workPerThread = 4;
-    }
+    this.isScater = isScater;
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workGroupSize,
        [this.workPerThread, 1, 1]);

    this.useSharedMemoryWithB = useSharedMemoryWithB;
    this.op = op;
-    // this.lastDimensionSize is used as sharedBuf array size, so can not be
-    // used as uniform.
-    this.shaderKey = `binaryShared_${op}_${this.lastDimensionSize}_${
-        this.useSharedMemoryWithB}`;
+    this.shaderKey =
+        `binaryShared_${op}_${this.useSharedMemoryWithB}_${isScater}`;
  }

  getUserCode(): string {
-    const sharedIndexSnippet = this.lastDimensionSize > 1 ?
-        `coords[${this.outputShape.length - 1}]` :
-        '0';
+    const sharedIndexSnippet =
+        this.isScater ? '0' : `coords[${this.outputShape.length - 1}]`;
    const accessDataSnippet = this.useSharedMemoryWithB ?
-        `let a = getAAtOutCoordsByCoords(coords);
+        `let a = getAAtOutCoordsByGlobalIndex(flatIndex);
         let b = sharedBuf[${sharedIndexSnippet}];` :
        `let a = sharedBuf[${sharedIndexSnippet}];
-         let b = getBAtOutCoordsByCoords(coords);`;
+         let b = getBAtOutCoordsByGlobalIndex(flatIndex);`;

-    const opStr = getBinaryOpString(this.op, false);
    const userCode = `
      fn binaryOperation(a : f32, b : f32) -> f32 {
-        ${opStr}
+        ${getBinaryOpString(this.op, false)}
      }
-      var<workgroup> sharedBuf : array<f32, ${this.lastDimensionSize}>;
+
+      var<workgroup> sharedBuf : array<f32, ${
+        this.workGroupSize[0] * this.workPerThread}>;
      ${getMainHeaderAndGlobalIndexString()}

      // Fill in the shared memory buffer. Here we need a loop to make sure
      // that all data in A|B are uploaded when |sharedMemorySize| is larger
      // than work group size.
      for(var localIndex = i32(localId.x); localIndex < ${
-          this.lastDimensionSize}; localIndex = localIndex + ${
-          this.workGroupSize[0]}) {
+          this.useSharedMemoryWithB ? 'uniforms.bShape' : 'uniforms.aShape'};
+          localIndex = localIndex + ${this.workGroupSize[0]}) {
        sharedBuf[localIndex] = f32(${
          this.useSharedMemoryWithB ? 'B' : 'A'}.numbers[localIndex]);
      }
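With this change the program no longer computes the broadcast output shape or detects the scalar case itself, so the call site has to supply `outputShape` and `isScater`. The sketch below illustrates what such a call site might look like; it is not part of this commit. The helper name `buildSharedBinaryProgram`, the import paths, and the size-based heuristics for `useSharedMemoryWithB` and `isScater` are assumptions for illustration; only `backend_util.assertAndGetBroadcastShape` and `util.sizeFromShape` are existing tfjs-core utilities, and the caller presumably also has to verify that the shared operand fits in the now-fixed sharedBuf of workGroupSize[0] * workPerThread = 1024 elements.

// Sketch only: shows the shape of a call site for the new constructor.
// Import paths and helper name are assumed, not taken from this commit.
import {backend_util, util} from '@tensorflow/tfjs-core';

import {BinaryOpSharedProgram} from './binary_op_shared_webgpu';
import {BinaryOpType} from './binary_op_util';

function buildSharedBinaryProgram(
    op: BinaryOpType, aShape: number[],
    bShape: number[]): BinaryOpSharedProgram {
  // The broadcast shape is computed by the caller now that the program
  // takes `outputShape` directly.
  const outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
  // Assumed heuristic: put the smaller operand into workgroup shared memory.
  const useSharedMemoryWithB =
      util.sizeFromShape(bShape) <= util.sizeFromShape(aShape);
  // A single-element shared operand is indexed as a scalar in the shader.
  const sharedShape = useSharedMemoryWithB ? bShape : aShape;
  const isScater = util.sizeFromShape(sharedShape) === 1;
  return new BinaryOpSharedProgram(
      op, outputShape, useSharedMemoryWithB, isScater);
}

Fixing sharedBuf at workGroupSize[0] * workPerThread and reading the fill-loop bound from uniforms is what lets the shaderKey drop lastDimensionSize: the removed comment notes that the data-dependent size could not be a uniform, whereas now programs with different input sizes can reuse the same compiled pipeline.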