forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresize_nearest_neighbor_webgpu.ts
90 lines (76 loc) · 3.21 KB
/
resize_nearest_neighbor_webgpu.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
/**
 * WebGPU program that resizes a 4-D input using nearest-neighbor sampling.
 *
 * The output keeps the first and last dimensions of the input and replaces
 * the two middle dimensions with `newHeight` and `newWidth`. One shader
 * invocation handles one flat output index.
 */
export class ResizeNearestNeighborProgram implements WebGPUProgram {
  /** Shape of the result: `[inputShape[0], newHeight, newWidth, inputShape[3]]`. */
  outputShape: number[];
  /** Cache key; it embeds `halfPixelCenters` because the shader text differs. */
  shaderKey: string;
  /** Layout returned by `flatDispatchLayout` for the output shape. */
  dispatchLayout: {x: number[]};
  /** Dispatch dimensions derived from layout, output shape and workgroup size. */
  dispatch: [number, number, number];
  /** Name of the single input; the shader reads it through `getX`. */
  variableNames = ['x'];
  /**
   * Extra uniform declarations read by the shader: the per-axis adjustment
   * subtracted from the in/out sizes, and the offset added before `floor`.
   * Their values are supplied by the caller, not by this class.
   */
  uniforms = 'adjustHeightWidth : vec2<f32>, roundBase : f32,';
  workgroupSize: [number, number, number] = [64, 1, 1];
  /** Selects which fractional-source-index expression is emitted. */
  halfPixelCenters: boolean;
  // NOTE(review): presumably asks the backend to expose `uniforms.size`, which
  // the shader uses as its bounds check - confirm against webgpu_program.
  size = true;

  /**
   * @param inputShape Shape of the 4-D input; only elements 0 and 3 are read.
   * @param newHeight Size of output dimension 1.
   * @param newWidth Size of output dimension 2.
   * @param halfPixelCenters Whether the shader offsets output coordinates by
   *     0.5 before scaling them into the input space.
   */
  constructor(
      inputShape: [number, number, number, number], newHeight: number,
      newWidth: number, halfPixelCenters: boolean) {
    const leading = inputShape[0];
    const trailing = inputShape[3];

    this.halfPixelCenters = halfPixelCenters;
    this.shaderKey = `resizeNearest_${halfPixelCenters}`;

    this.outputShape = [leading, newHeight, newWidth, trailing];
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(
        this.dispatchLayout, this.outputShape, this.workgroupSize);
  }

  /**
   * Builds the WGSL body for this program.
   *
   * @returns Shader source that, for every in-range flat index, maps the
   *     output row/column to a fractional input position, adds
   *     `uniforms.roundBase`, floors, clamps to the last input row/column and
   *     copies that input element to the output.
   */
  getUserCode(): string {
    // WGSL expression for the fractional (row, col) position in the input.
    const plainScale = `vec2<f32>(rc) * effectiveInputOverOutputRatioRC`;
    const centeredScale =
        `(vec2<f32>(rc) + vec2<f32>(0.5)) * effectiveInputOverOutputRatioRC`;
    const fracIndexExpr = this.halfPixelCenters ?
        `max(${centeredScale}, vec2<f32>(0.0))` :
        plainScale;

    // The template lines stay flush-left on purpose: their whitespace is part
    // of the emitted shader string.
    return `
${main('index')} {
if (index < uniforms.size) {
let coords = getCoordsFromIndex(index);
let b = coords[0];
let d = coords[3];
let rc = coords.yz;
let effectiveInSize = vec2<f32>(
f32(uniforms.xShape.y) - uniforms.adjustHeightWidth[0],
f32(uniforms.xShape.z) - uniforms.adjustHeightWidth[1]);
let effectiveOutSize = vec2<f32>(
f32(uniforms.outShape.y) - uniforms.adjustHeightWidth[0],
f32(uniforms.outShape.z) - uniforms.adjustHeightWidth[1]);
let effectiveInputOverOutputRatioRC =
effectiveInSize / effectiveOutSize;
// Fractional source index
let sourceFracIndexRC = ${fracIndexExpr};
// Compute the coordinators of nearest neighbor point.
let inputShapeRC = vec2<f32>(f32(uniforms.xShape.y), f32(uniforms.xShape.z));
let sourceNearestRC = vec2<i32>(
min(inputShapeRC - 1.0, floor(sourceFracIndexRC + uniforms.roundBase)));
let newValue = getX(b, sourceNearestRC.x, sourceNearestRC.y, d);
setOutputAtIndex(index, newValue);
}
}
`;
  }
}