Skip to content

Commit 95e44a5

Browse files
authored
add experimental webgl support for RN platform (tensorflow#1844)
* fix tests add webgl backend prototype * webgl path for ios * save * avoid shadowing global functions in clip * upgrade tfjs version * remove unused constant in shader compiler * Merge branch 'master' into rn-webgl * code review fixes improve float texture download detection to include half float textures. * save * save * remove warning as it makes tests noisy consider refactoring this in another PR. * Merge branch 'master' into rn-webgl * revert settings update * add comment * save * fix fetching of binary files * save * save * have tests spy on platform.fetch * Merge branch 'master' into rn-webgl * code review fixes * save * Merge branch 'master' into rn-webgl * fix lint errors * code review fixes
1 parent 03c205e commit 95e44a5

29 files changed

+505
-251
lines changed

.npmignore

+2
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ rollup.config.js
2222
tsconfig.json
2323
.yalc/
2424
yalc.lock
25+
tfjs-react-native/
26+
tfjs-backend-nodegl/

src/backends/webgl/clip_gpu.ts

+5-5
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ export class ClipProgram implements GPGPUProgram {
3030
constructor(aShape: number[]) {
3131
this.outputShape = aShape;
3232
this.userCode = `
33-
uniform float min;
34-
uniform float max;
33+
uniform float minVal;
34+
uniform float maxVal;
3535
3636
void main() {
3737
float value = getAAtOutCoords();
@@ -40,16 +40,16 @@ export class ClipProgram implements GPGPUProgram {
4040
return;
4141
}
4242
43-
setOutput(clamp(value, min, max));
43+
setOutput(clamp(value, minVal, maxVal));
4444
}
4545
`;
4646
}
4747

4848
getCustomSetupFunc(min: number, max: number) {
4949
return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
5050
if (this.minLoc == null) {
51-
this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'min');
52-
this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'max');
51+
this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
52+
this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
5353
}
5454
gpgpu.gl.uniform1f(this.minLoc, min);
5555
gpgpu.gl.uniform1f(this.maxLoc, max);

src/backends/webgl/clip_packed_gpu.ts

+5-5
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ export class ClipPackedProgram implements GPGPUProgram {
3131
constructor(aShape: number[]) {
3232
this.outputShape = aShape;
3333
this.userCode = `
34-
uniform float min;
35-
uniform float max;
34+
uniform float minVal;
35+
uniform float maxVal;
3636
3737
void main() {
3838
vec4 value = getAAtOutCoords();
@@ -42,16 +42,16 @@ export class ClipPackedProgram implements GPGPUProgram {
4242
return;
4343
}
4444
45-
setOutput(clamp(value, vec4(min), vec4(max)));
45+
setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
4646
}
4747
`;
4848
}
4949

5050
getCustomSetupFunc(min: number, max: number) {
5151
return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
5252
if (this.minLoc == null) {
53-
this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'min');
54-
this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'max');
53+
this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
54+
this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
5555
}
5656
gpgpu.gl.uniform1f(this.minLoc, min);
5757
gpgpu.gl.uniform1f(this.maxLoc, max);

src/backends/webgl/glsl_version.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ export function getGlslDifferences(): GLSL {
5555
return (val > 0. || val < 0. || val == 0.) ? false : true;
5656
}
5757
`;
58-
defineSpecialInf = `
59-
const float INFINITY = uintBitsToFloat(uint(0x7f800000));
60-
`;
58+
// In webgl 2 we do not need to specify a custom isinf so there is no
59+
// need for a special INFINITY constant.
60+
defineSpecialInf = ``;
6161
defineRound = `
6262
#define round(value) newRound(value)
6363
int newRound(float value) {

src/backends/webgl/gpgpu_context.ts

+13-4
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import * as util from '../../util';
2121

2222
import {getWebGLContext, setWebGLContext} from './canvas_util';
2323
import * as gpgpu_util from './gpgpu_util';
24-
import {TextureConfig} from './gpgpu_util';
2524
import * as tex_util from './tex_util';
25+
import {TextureConfig} from './tex_util';
2626
import {WebGL1DisjointQueryTimerExtension, WebGL2DisjointQueryTimerExtension} from './webgl_types';
2727
import * as webgl_util from './webgl_util';
2828

@@ -70,16 +70,25 @@ export class GPGPUContext {
7070
this.gl.getExtension('EXT_color_buffer_half_float');
7171
}
7272
} else {
73-
this.colorBufferFloatExtension = webgl_util.getExtensionOrThrow(
74-
this.gl, this.debug, 'EXT_color_buffer_float');
73+
const COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
74+
const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
75+
if (webgl_util.hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
76+
this.colorBufferFloatExtension =
77+
this.gl.getExtension(COLOR_BUFFER_FLOAT);
78+
} else if (webgl_util.hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
79+
this.colorBufferHalfFloatExtension =
80+
this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
81+
} else {
82+
throw new Error('GL context does not support color renderable floats');
83+
}
7584
}
7685

7786
this.vertexBuffer = gpgpu_util.createVertexBuffer(this.gl, this.debug);
7887
this.indexBuffer = gpgpu_util.createIndexBuffer(this.gl, this.debug);
7988
this.framebuffer = webgl_util.createFramebuffer(this.gl, this.debug);
8089

8190
this.textureConfig =
82-
gpgpu_util.getTextureConfig(this.gl, this.textureHalfFloatExtension);
91+
tex_util.getTextureConfig(this.gl, this.textureHalfFloatExtension);
8392
}
8493

8594
private get debug(): boolean {

src/backends/webgl/gpgpu_util.ts

+1-71
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,13 @@
1515
* =============================================================================
1616
*/
1717

18-
import {ENV} from '../../environment';
1918
import {PixelData, TypedArray} from '../../types';
2019

2120
import {getGlslDifferences} from './glsl_version';
2221
import * as tex_util from './tex_util';
22+
import {TextureConfig} from './tex_util';
2323
import * as webgl_util from './webgl_util';
2424

25-
export interface TextureConfig {
26-
internalFormatFloat: number;
27-
textureFormatFloat: number;
28-
internalFormatPackedHalfFloat: number;
29-
internalFormatHalfFloat: number;
30-
internalFormatPackedFloat: number;
31-
32-
// The format to use during a gl.readPixels call.
33-
downloadTextureFormat: number;
34-
// How many channels need to be unpacked after a gl.readPixels call.
35-
downloadUnpackNumChannels: number;
36-
37-
defaultNumChannels: number;
38-
textureTypeHalfFloat: number;
39-
}
40-
4125
export function createVertexShader(
4226
gl: WebGLRenderingContext, debug: boolean): WebGLShader {
4327
const glsl = getGlslDifferences();
@@ -69,60 +53,6 @@ export function createIndexBuffer(
6953
return webgl_util.createStaticIndexBuffer(gl, debug, triangleVertexIndices);
7054
}
7155

72-
export function getTextureConfig(
73-
// tslint:disable-next-line:no-any
74-
gl: WebGLRenderingContext, textureHalfFloatExtension?: any): TextureConfig {
75-
// tslint:disable-next-line:no-any
76-
const glany = gl as any;
77-
78-
let internalFormatFloat: number;
79-
let internalFormatHalfFloat: number;
80-
let internalFormatPackedHalfFloat: number;
81-
let internalFormatPackedFloat: number;
82-
let textureFormatFloat: number;
83-
84-
let downloadTextureFormat: number;
85-
let downloadUnpackNumChannels: number;
86-
87-
let defaultNumChannels: number;
88-
let textureTypeHalfFloat: number;
89-
90-
if (ENV.getNumber('WEBGL_VERSION') === 2) {
91-
internalFormatFloat = glany.R32F;
92-
internalFormatHalfFloat = glany.R16F;
93-
internalFormatPackedHalfFloat = glany.RGBA16F;
94-
internalFormatPackedFloat = glany.RGBA32F;
95-
textureFormatFloat = glany.RED;
96-
downloadUnpackNumChannels = 4;
97-
defaultNumChannels = 1;
98-
textureTypeHalfFloat = glany.HALF_FLOAT;
99-
} else {
100-
internalFormatFloat = gl.RGBA;
101-
internalFormatHalfFloat = gl.RGBA;
102-
internalFormatPackedHalfFloat = gl.RGBA;
103-
internalFormatPackedFloat = glany.RGBA;
104-
textureFormatFloat = gl.RGBA;
105-
downloadUnpackNumChannels = 4;
106-
defaultNumChannels = 4;
107-
textureTypeHalfFloat = textureHalfFloatExtension != null ?
108-
textureHalfFloatExtension.HALF_FLOAT_OES :
109-
null;
110-
}
111-
downloadTextureFormat = gl.RGBA;
112-
113-
return {
114-
internalFormatFloat,
115-
internalFormatHalfFloat,
116-
internalFormatPackedHalfFloat,
117-
internalFormatPackedFloat,
118-
textureFormatFloat,
119-
downloadTextureFormat,
120-
downloadUnpackNumChannels,
121-
defaultNumChannels,
122-
textureTypeHalfFloat
123-
};
124-
}
125-
12656
function createAndConfigureTexture(
12757
gl: WebGLRenderingContext, debug: boolean, width: number, height: number,
12858
internalFormat: number, textureFormat: number,

src/backends/webgl/gpgpu_util_test.ts

+5-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import {describeWithFlags} from '../../jasmine_util';
1919
import {WEBGL_ENVS} from './backend_webgl_test_registry';
2020
import {GPGPUContext} from './gpgpu_context';
2121
import * as gpgpu_util from './gpgpu_util';
22+
import * as tex_util from './tex_util';
2223

2324
describeWithFlags('gpgpu_util createWebGLContext', WEBGL_ENVS, () => {
2425
let gpgpu: GPGPUContext;
@@ -59,7 +60,7 @@ describeWithFlags('gpgpu_util createWebGLContext', WEBGL_ENVS, () => {
5960
describeWithFlags('gpgpu_util createFloat32MatrixTexture', WEBGL_ENVS, () => {
6061
it('sets the TEXTURE_WRAP S+T parameters to CLAMP_TO_EDGE', () => {
6162
const gpgpu = new GPGPUContext();
62-
const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl);
63+
const textureConfig = tex_util.getTextureConfig(gpgpu.gl);
6364
const debug = false;
6465
const tex = gpgpu_util.createFloat32MatrixTexture(
6566
gpgpu.gl, debug, 32, 32, textureConfig);
@@ -77,7 +78,7 @@ describeWithFlags('gpgpu_util createFloat32MatrixTexture', WEBGL_ENVS, () => {
7778

7879
it('sets the TEXTURE_[MIN|MAG]_FILTER parameters to NEAREST', () => {
7980
const gpgpu = new GPGPUContext();
80-
const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl);
81+
const textureConfig = tex_util.getTextureConfig(gpgpu.gl);
8182
const debug = false;
8283
const tex = gpgpu_util.createFloat32MatrixTexture(
8384
gpgpu.gl, debug, 32, 32, textureConfig);
@@ -97,7 +98,7 @@ describeWithFlags('gpgpu_util createFloat32MatrixTexture', WEBGL_ENVS, () => {
9798
describeWithFlags('gpgpu_util createPackedMatrixTexture', WEBGL_ENVS, () => {
9899
it('sets the TEXTURE_WRAP S+T parameters to CLAMP_TO_EDGE', () => {
99100
const gpgpu = new GPGPUContext();
100-
const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl);
101+
const textureConfig = tex_util.getTextureConfig(gpgpu.gl);
101102
const debug = false;
102103
const tex = gpgpu_util.createPackedMatrixTexture(
103104
gpgpu.gl, debug, 32, 32, textureConfig);
@@ -115,7 +116,7 @@ describeWithFlags('gpgpu_util createPackedMatrixTexture', WEBGL_ENVS, () => {
115116

116117
it('sets the TEXTURE_[MIN|MAG]_FILTER parameters to NEAREST', () => {
117118
const gpgpu = new GPGPUContext();
118-
const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl);
119+
const textureConfig = tex_util.getTextureConfig(gpgpu.gl);
119120
const debug = false;
120121
const tex = gpgpu_util.createPackedMatrixTexture(
121122
gpgpu.gl, debug, 32, 32, textureConfig);

src/backends/webgl/tex_util.ts

+76
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18+
import {ENV} from '../../environment';
1819
import {DataId, Tensor} from '../../tensor';
1920
import {BackendValues, DataType} from '../../types';
2021
import * as util from '../../util';
@@ -121,3 +122,78 @@ export function getPackedRGBAArraySizeFromMatrixShape(
121122
const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
122123
return w * h * 4;
123124
}
125+
126+
export interface TextureConfig {
127+
internalFormatFloat: number;
128+
textureFormatFloat: number;
129+
internalFormatPackedHalfFloat: number;
130+
internalFormatHalfFloat: number;
131+
internalFormatPackedFloat: number;
132+
133+
// The format to use during a gl.readPixels call.
134+
downloadTextureFormat: number;
135+
// How many channels need to be unpacked after a gl.readPixels call.
136+
downloadUnpackNumChannels: number;
137+
138+
defaultNumChannels: number;
139+
textureTypeHalfFloat: number;
140+
textureTypeFloat: number;
141+
}
142+
143+
export function getTextureConfig(
144+
// tslint:disable-next-line:no-any
145+
gl: WebGLRenderingContext, textureHalfFloatExtension?: any): TextureConfig {
146+
// tslint:disable-next-line:no-any
147+
const glany = gl as any;
148+
149+
let internalFormatFloat: number;
150+
let internalFormatHalfFloat: number;
151+
let internalFormatPackedHalfFloat: number;
152+
let internalFormatPackedFloat: number;
153+
let textureFormatFloat: number;
154+
155+
let downloadTextureFormat: number;
156+
let downloadUnpackNumChannels: number;
157+
158+
let defaultNumChannels: number;
159+
let textureTypeHalfFloat: number;
160+
let textureTypeFloat: number;
161+
162+
if (ENV.getNumber('WEBGL_VERSION') === 2) {
163+
internalFormatFloat = glany.R32F;
164+
internalFormatHalfFloat = glany.R16F;
165+
internalFormatPackedHalfFloat = glany.RGBA16F;
166+
internalFormatPackedFloat = glany.RGBA32F;
167+
textureFormatFloat = glany.RED;
168+
downloadUnpackNumChannels = 4;
169+
defaultNumChannels = 1;
170+
textureTypeHalfFloat = glany.HALF_FLOAT;
171+
textureTypeFloat = glany.FLOAT;
172+
} else {
173+
internalFormatFloat = gl.RGBA;
174+
internalFormatHalfFloat = gl.RGBA;
175+
internalFormatPackedHalfFloat = gl.RGBA;
176+
internalFormatPackedFloat = glany.RGBA;
177+
textureFormatFloat = gl.RGBA;
178+
downloadUnpackNumChannels = 4;
179+
defaultNumChannels = 4;
180+
textureTypeHalfFloat = textureHalfFloatExtension != null ?
181+
textureHalfFloatExtension.HALF_FLOAT_OES :
182+
null;
183+
textureTypeFloat = gl.FLOAT;
184+
}
185+
downloadTextureFormat = gl.RGBA;
186+
187+
return {
188+
internalFormatFloat,
189+
internalFormatHalfFloat,
190+
internalFormatPackedHalfFloat,
191+
internalFormatPackedFloat,
192+
textureFormatFloat,
193+
downloadTextureFormat,
194+
downloadUnpackNumChannels,
195+
defaultNumChannels,
196+
textureTypeHalfFloat,
197+
textureTypeFloat
198+
};
199+
}

0 commit comments

Comments
 (0)