forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcum_webgpu.ts
117 lines (109 loc) · 3.88 KB
/
cum_webgpu.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
/**
 * Binary operator applied by a cumulative op. Each value is the WGSL operator
 * token that CumProgram splices into its shader as a compound assignment
 * (e.g. `val *= ...` / `val += ...`).
 */
export enum CumOpType {
// Cumulative product; CumProgram uses 1.0 as its identity value.
Prod = '*',
// Cumulative sum; CumProgram uses 0.0 as its identity value.
Sum = '+',
}
/**
 * WebGPU program implementing one pass of a cumulative sum/product along the
 * last axis of its single input `x`.
 *
 * Inclusive mode (`exclusive === false`): every element at position `end` on
 * the last axis is combined with the element `2^index` positions away
 * (towards the start of the axis, or towards the end when `reverse` is set),
 * where `index` is read from the `index` uniform. Presumably the caller runs
 * this once per power of two to build the full scan — confirm at call site.
 *
 * Exclusive mode (`exclusive === true`): the pass shifts values by one
 * position in the scan direction and writes the operator's identity
 * (1.0 for product, 0.0 for sum) into the vacated slot.
 */
export class CumProgram implements WebGPUProgram {
  outputShape: number[];
  shaderKey: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames = ['x'];
  workgroupSize: [number, number, number];
  // Pass index. Kept as f32 to preserve the existing uniform layout; the
  // shader converts it to u32 and computes 2^index with an exact bit shift.
  uniforms = 'index : f32,';
  size = true;
  exclusive: boolean;
  reverse: boolean;
  op: CumOpType;

  /**
   * @param op Combining operator (`*` for product, `+` for sum).
   * @param shape Shape of both input and output; the scan runs along the
   *     last dimension.
   * @param exclusive When true, generate the one-step shift pass used for
   *     exclusive scans instead of a combining pass.
   * @param reverse When true, accumulate from the end of the last axis
   *     towards its start.
   */
  constructor(
      op: CumOpType, shape: number[], exclusive: boolean, reverse: boolean) {
    this.workgroupSize = [128, 1, 1];
    this.outputShape = shape;
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize);
    this.exclusive = exclusive;
    this.reverse = reverse;
    this.op = op;
    // getUserCode bakes the last-axis length into the reverse-direction
    // shader source, so it must be part of the key; otherwise two programs
    // with different axis lengths would share a key but differ in source.
    const lengthKey =
        this.reverse ? `_${this.outputShape[this.outputShape.length - 1]}` : '';
    this.shaderKey =
        `cum_${this.op}_${this.exclusive}_${this.reverse}${lengthKey}`;
  }

  /**
   * Generates the WGSL body for this pass.
   *
   * @throws Error if the output rank is not between 1 and 4.
   */
  getUserCode(): string {
    const rank = this.outputShape.length;
    const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0';
    const coordsArgs = getCoords(rank, 'coords', this.op);
    const finalCoord = getFinalCoord(rank, 'coords', this.op);
    const val = this.exclusive ? initVal : `getX(${coordsArgs})`;
    const length = this.outputShape[this.outputShape.length - 1];
    let condition = '';
    let idxString = '';
    // When exclusive is set, the cum op becomes roll op that copies the
    // value from the previous index based on the direction specified by the
    // reverse flag.
    if (this.exclusive) {
      condition = this.reverse ? `end != ${length - 1}` : 'end != 0';
      idxString = this.reverse ? 'end + 1' : 'end - 1';
    } else {
      condition = this.reverse ? `end + pow2 < ${length}` : 'end >= pow2';
      idxString = this.reverse ? 'end + pow2' : 'end - pow2';
    }
    // 2^index is computed with an integer shift: WGSL only guarantees pow()
    // to the accuracy of exp2(y * log2(x)), so i32(pow(2.0, k)) may truncate
    // a result that lands just below the exact power of two.
    return `
      ${main('index')} {
        if (index < uniforms.size) {
          var coords = getCoordsFromIndex(index);
          let end = ${finalCoord};
          var val = ${val};
          let pow2 = i32(1) << u32(uniforms.index);
          if (${condition}) {
            let idx = ${idxString};
            ${finalCoord} = idx;
            val ${this.op}= getX(${coordsArgs});
          }
          setOutputAtIndex(index, val);
        }
      }
    `;
  }
}
/**
 * Builds the comma-separated WGSL argument list that spells out every
 * component of the coordinate variable `name` for a tensor of `rank`
 * dimensions (e.g. `coords.x, coords.y` for rank 2).
 *
 * @param rank Tensor rank; only 1 through 4 are supported.
 * @param name Name of the WGSL coordinate variable.
 * @param op Operator, used only in the error message.
 * @throws Error if `rank` is outside 1..4.
 */
function getCoords(rank: number, name: string, op: CumOpType): string {
  switch (rank) {
    case 1:
      // A rank-1 coordinate is used directly, without component selectors.
      return name;
    case 2:
    case 3:
    case 4:
      return ['x', 'y', 'z', 'w']
          .slice(0, rank)
          .map(component => `${name}.${component}`)
          .join(', ');
    default:
      throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
  }
}
/**
 * Returns the WGSL expression that selects the last coordinate component of
 * the coordinate variable `name` for a tensor of `rank` dimensions
 * (e.g. `coords.z` for rank 3).
 *
 * @param rank Tensor rank; only 1 through 4 are supported.
 * @param name Name of the WGSL coordinate variable.
 * @param op Operator, used only in the error message.
 * @throws Error if `rank` is outside 1..4.
 */
function getFinalCoord(rank: number, name: string, op: CumOpType): string {
  // Entry i is the component selector for rank i + 1. A rank-1 coordinate is
  // used directly, so its selector is empty.
  const selector = ['', '.y', '.z', '.w'][rank - 1];
  if (selector === undefined) {
    throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
  }
  return name + selector;
}