Skip to content

Commit

Permalink
fix: GPU and add official test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
robertleeplummerjr committed Nov 18, 2018
1 parent 0734989 commit a770fef
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 169 deletions.
101 changes: 38 additions & 63 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: Heather Arthur <[email protected]>
* homepage: https://github.com/brainjs/brain.js#readme
* version: 1.5.0-alpha
* version: 1.5.0
*
* acorn:
* license: MIT (http://opensource.org/licenses/MIT)
Expand Down Expand Up @@ -41,7 +41,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: The gpu.js Team
* homepage: http://gpu.rocks/
* version: 1.10.2
* version: 1.10.3
*
* ieee754:
* license: BSD-3-Clause (http://opensource.org/licenses/BSD-3-Clause)
Expand Down Expand Up @@ -587,9 +587,9 @@ var NeuralNetworkGPU = function (_NeuralNetwork) {


_createClass(NeuralNetworkGPU, [{
key: '_initialize',
value: function _initialize() {
_get(NeuralNetworkGPU.prototype.__proto__ || Object.getPrototypeOf(NeuralNetworkGPU.prototype), '_initialize', this).call(this);
key: 'initialize',
value: function initialize() {
_get(NeuralNetworkGPU.prototype.__proto__ || Object.getPrototypeOf(NeuralNetworkGPU.prototype), 'initialize', this).call(this);
this.buildRunInput();
this.buildCalculateDeltas();
this.buildGetChanges();
Expand All @@ -608,8 +608,8 @@ var NeuralNetworkGPU = function (_NeuralNetwork) {
*/

}, {
key: '_trainPattern',
value: function _trainPattern(input, target, logErrorRate) {
key: 'trainPattern',
value: function trainPattern(input, target, logErrorRate) {
// forward propagate
this.runInput(input);

Expand Down Expand Up @@ -657,7 +657,7 @@ var NeuralNetworkGPU = function (_NeuralNetwork) {
});
}

this._texturizeInputData = this.gpu.createKernel(function (value) {
this.texturizeInputData = this.gpu.createKernel(function (value) {
return value[this.thread.x];
}, {
output: [this.sizes[1]],
Expand Down Expand Up @@ -856,50 +856,23 @@ var NeuralNetworkGPU = function (_NeuralNetwork) {
value: function run(input) {
if (!this.isRunnable) return null;
if (this.inputLookup) {
input = _lookup2.default.toArray(this.inputLookup, input);
input = _lookup2.default.toArray(this.inputLookup, input, this.inputLookupLength);
}
var inputTexture = this._texturizeInputData(input);
var inputTexture = this.texturizeInputData(input);
var outputTextures = this.runInput(inputTexture);
var output = outputTextures.toArray(this.gpu);
var output = void 0;
if (outputTextures.toArray) {
output = outputTextures.toArray(this.gpu);
} else {
output = outputTextures;
}

if (this.outputLookup) {
output = _lookup2.default.toObject(this.outputLookup, output);
}
return output;
}

/**
*
* @param data
* Verifies network sizes are initilaized
* If they are not it will initialize them based off the data set.
*/

}, {
key: '_verifyIsInitialized',
value: function _verifyIsInitialized(data) {
var _this2 = this;

if (this.sizes) return;

this.sizes = [];
if (!data[0].size) {
data[0].size = { input: data[0].input.length, output: data[0].output.length };
}

this.sizes.push(data[0].size.input);
if (!this.hiddenLayers) {
this.sizes.push(Math.max(3, Math.floor(data[0].size.input / 2)));
} else {
this.hiddenLayers.forEach(function (size) {
_this2.sizes.push(size);
});
}
this.sizes.push(data[0].size.output);

this._initialize();
}

/**
*
* @param data
Expand All @@ -909,11 +882,11 @@ var NeuralNetworkGPU = function (_NeuralNetwork) {
*/

}, {
key: '_prepTraining',
value: function _prepTraining(data, options) {
var _this3 = this;
key: 'prepTraining',
value: function prepTraining(data, options) {
var _this2 = this;

this._updateTrainingOptions(options);
this.updateTrainingOptions(options);
data = this.formatData(data);
var endTime = Date.now() + this.trainOpts.timeout;

Expand All @@ -922,7 +895,7 @@ var NeuralNetworkGPU = function (_NeuralNetwork) {
iterations: 0
};

this._verifyIsInitialized(data);
this.verifyIsInitialized(data);

var texturizeOutputData = this.gpu.createKernel(function (value) {
return value[this.thread.x];
Expand All @@ -936,8 +909,7 @@ var NeuralNetworkGPU = function (_NeuralNetwork) {
return {
data: data.map(function (set) {
return {
size: set.size,
input: _this3._texturizeInputData(set.input),
input: _this2.texturizeInputData(set.input),
output: texturizeOutputData(set.output)
};
}),
Expand Down Expand Up @@ -9833,7 +9805,7 @@ module.exports = function (_KernelBase) {
* @constructor CPUKernel
*
* @desc Kernel Implementation for CPU.
*
*
* <p>Instantiates properties to the CPU Kernel.</p>
*
* @extends KernelBase
Expand Down Expand Up @@ -9874,7 +9846,7 @@ module.exports = function (_KernelBase) {
* @function
* @name validateOptions
*
* @desc Validate options related to CPU Kernel, such as
* @desc Validate options related to CPU Kernel, such as
* dimensions size, and auto dimension support.
*
*/
Expand Down Expand Up @@ -9906,8 +9878,8 @@ module.exports = function (_KernelBase) {
* @function
* @name build
*
* @desc Builds the Kernel, by generating the kernel
* string using thread dimensions, and arguments
* @desc Builds the Kernel, by generating the kernel
* string using thread dimensions, and arguments
* supplied to the kernel.
*
* <p>If the graphical flag is enabled, canvas is used.</p>
Expand All @@ -9921,16 +9893,23 @@ module.exports = function (_KernelBase) {
this.setupParams(arguments);
this.validateOptions();
var canvas = this._canvas;
this._canvasCtx = canvas.getContext('2d');
if (canvas) {
// if node or canvas is not found, don't die
this._canvasCtx = canvas.getContext('2d');
}
var threadDim = this.threadDim = utils.clone(this.output);

while (threadDim.length < 3) {
threadDim.push(1);
}

if (this.graphical) {
canvas.width = threadDim[0];
canvas.height = threadDim[1];
var _canvas = this._canvas;
if (!_canvas) {
throw new Error('no canvas available for using graphical output');
}
_canvas.width = threadDim[0];
_canvas.height = threadDim[1];
this._imageData = this._canvasCtx.createImageData(threadDim[0], threadDim[1]);
this._colorData = new Uint8ClampedArray(threadDim[0] * threadDim[1] * 4);
}
Expand Down Expand Up @@ -9979,7 +9958,7 @@ module.exports = function (_KernelBase) {
* @name getKernelString
*
* @desc Generates kernel string for this kernel program.
*
*
* <p>If sub-kernels are supplied, they are also factored in.
* This string can be saved by calling the `toString` method
* and then can be reused later.</p>
Expand Down Expand Up @@ -14834,7 +14813,7 @@ module.exports = function (_KernelBase) {
floatOutput: this.floatOutput
}, paramDim);

result.push('uniform sampler2D user_' + paramName, 'ivec2 user_' + paramName + 'Size = vec2(' + paramSize[0] + ', ' + paramSize[1] + ')', 'ivec3 user_' + paramName + 'Dim = vec3(' + paramDim[0] + ', ' + paramDim[1] + ', ' + paramDim[2] + ')', 'uniform int user_' + paramName + 'BitRatio');
result.push('uniform sampler2D user_' + paramName, 'ivec2 user_' + paramName + 'Size = ivec2(' + paramSize[0] + ', ' + paramSize[1] + ')', 'ivec3 user_' + paramName + 'Dim = ivec3(' + paramDim[0] + ', ' + paramDim[1] + ', ' + paramDim[2] + ')', 'uniform int user_' + paramName + 'BitRatio');
} else if (paramType === 'Integer') {
result.push('float user_' + paramName + ' = ' + param + '.0');
} else if (paramType === 'Float') {
Expand Down Expand Up @@ -16249,10 +16228,6 @@ module.exports = function (_WebGLKernel) {
}, paramDim);

result.push('uniform highp sampler2D user_' + paramName, 'highp ivec2 user_' + paramName + 'Size = ivec2(' + paramSize[0] + ', ' + paramSize[1] + ')', 'highp ivec3 user_' + paramName + 'Dim = ivec3(' + paramDim[0] + ', ' + paramDim[1] + ', ' + paramDim[2] + ')', 'uniform highp int user_' + paramName + 'BitRatio');

if (paramType === 'Array') {
result.push('uniform highp int user_' + paramName + 'BitRatio');
}
} else if (paramType === 'Integer') {
result.push('highp float user_' + paramName + ' = ' + param + '.0');
} else if (paramType === 'Float') {
Expand Down
Loading

0 comments on commit a770fef

Please sign in to comment.