forked from BrainJS/brain.js
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: GPU and add official test for it
- Loading branch information
1 parent
0734989
commit a770fef
Showing
7 changed files
with
121 additions
and
169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
* license: MIT (http://opensource.org/licenses/MIT) | ||
* author: Heather Arthur <[email protected]> | ||
* homepage: https://github.com/brainjs/brain.js#readme | ||
* version: 1.5.0-alpha | ||
* version: 1.5.0 | ||
* | ||
* acorn: | ||
* license: MIT (http://opensource.org/licenses/MIT) | ||
|
@@ -41,7 +41,7 @@ | |
* license: MIT (http://opensource.org/licenses/MIT) | ||
* author: The gpu.js Team | ||
* homepage: http://gpu.rocks/ | ||
* version: 1.10.2 | ||
* version: 1.10.3 | ||
* | ||
* ieee754: | ||
* license: BSD-3-Clause (http://opensource.org/licenses/BSD-3-Clause) | ||
|
@@ -587,9 +587,9 @@ var NeuralNetworkGPU = function (_NeuralNetwork) { | |
|
||
|
||
_createClass(NeuralNetworkGPU, [{ | ||
key: '_initialize', | ||
value: function _initialize() { | ||
_get(NeuralNetworkGPU.prototype.__proto__ || Object.getPrototypeOf(NeuralNetworkGPU.prototype), '_initialize', this).call(this); | ||
key: 'initialize', | ||
value: function initialize() { | ||
_get(NeuralNetworkGPU.prototype.__proto__ || Object.getPrototypeOf(NeuralNetworkGPU.prototype), 'initialize', this).call(this); | ||
this.buildRunInput(); | ||
this.buildCalculateDeltas(); | ||
this.buildGetChanges(); | ||
|
@@ -608,8 +608,8 @@ var NeuralNetworkGPU = function (_NeuralNetwork) { | |
*/ | ||
|
||
}, { | ||
key: '_trainPattern', | ||
value: function _trainPattern(input, target, logErrorRate) { | ||
key: 'trainPattern', | ||
value: function trainPattern(input, target, logErrorRate) { | ||
// forward propagate | ||
this.runInput(input); | ||
|
||
|
@@ -657,7 +657,7 @@ var NeuralNetworkGPU = function (_NeuralNetwork) { | |
}); | ||
} | ||
|
||
this._texturizeInputData = this.gpu.createKernel(function (value) { | ||
this.texturizeInputData = this.gpu.createKernel(function (value) { | ||
return value[this.thread.x]; | ||
}, { | ||
output: [this.sizes[1]], | ||
|
@@ -856,50 +856,23 @@ var NeuralNetworkGPU = function (_NeuralNetwork) { | |
value: function run(input) { | ||
if (!this.isRunnable) return null; | ||
if (this.inputLookup) { | ||
input = _lookup2.default.toArray(this.inputLookup, input); | ||
input = _lookup2.default.toArray(this.inputLookup, input, this.inputLookupLength); | ||
} | ||
var inputTexture = this._texturizeInputData(input); | ||
var inputTexture = this.texturizeInputData(input); | ||
var outputTextures = this.runInput(inputTexture); | ||
var output = outputTextures.toArray(this.gpu); | ||
var output = void 0; | ||
if (outputTextures.toArray) { | ||
output = outputTextures.toArray(this.gpu); | ||
} else { | ||
output = outputTextures; | ||
} | ||
|
||
if (this.outputLookup) { | ||
output = _lookup2.default.toObject(this.outputLookup, output); | ||
} | ||
return output; | ||
} | ||
|
||
/** | ||
* | ||
* @param data | ||
* Verifies network sizes are initilaized | ||
* If they are not it will initialize them based off the data set. | ||
*/ | ||
|
||
}, { | ||
key: '_verifyIsInitialized', | ||
value: function _verifyIsInitialized(data) { | ||
var _this2 = this; | ||
|
||
if (this.sizes) return; | ||
|
||
this.sizes = []; | ||
if (!data[0].size) { | ||
data[0].size = { input: data[0].input.length, output: data[0].output.length }; | ||
} | ||
|
||
this.sizes.push(data[0].size.input); | ||
if (!this.hiddenLayers) { | ||
this.sizes.push(Math.max(3, Math.floor(data[0].size.input / 2))); | ||
} else { | ||
this.hiddenLayers.forEach(function (size) { | ||
_this2.sizes.push(size); | ||
}); | ||
} | ||
this.sizes.push(data[0].size.output); | ||
|
||
this._initialize(); | ||
} | ||
|
||
/** | ||
* | ||
* @param data | ||
|
@@ -909,11 +882,11 @@ var NeuralNetworkGPU = function (_NeuralNetwork) { | |
*/ | ||
|
||
}, { | ||
key: '_prepTraining', | ||
value: function _prepTraining(data, options) { | ||
var _this3 = this; | ||
key: 'prepTraining', | ||
value: function prepTraining(data, options) { | ||
var _this2 = this; | ||
|
||
this._updateTrainingOptions(options); | ||
this.updateTrainingOptions(options); | ||
data = this.formatData(data); | ||
var endTime = Date.now() + this.trainOpts.timeout; | ||
|
||
|
@@ -922,7 +895,7 @@ var NeuralNetworkGPU = function (_NeuralNetwork) { | |
iterations: 0 | ||
}; | ||
|
||
this._verifyIsInitialized(data); | ||
this.verifyIsInitialized(data); | ||
|
||
var texturizeOutputData = this.gpu.createKernel(function (value) { | ||
return value[this.thread.x]; | ||
|
@@ -936,8 +909,7 @@ var NeuralNetworkGPU = function (_NeuralNetwork) { | |
return { | ||
data: data.map(function (set) { | ||
return { | ||
size: set.size, | ||
input: _this3._texturizeInputData(set.input), | ||
input: _this2.texturizeInputData(set.input), | ||
output: texturizeOutputData(set.output) | ||
}; | ||
}), | ||
|
@@ -9833,7 +9805,7 @@ module.exports = function (_KernelBase) { | |
* @constructor CPUKernel | ||
* | ||
* @desc Kernel Implementation for CPU. | ||
* | ||
* | ||
* <p>Instantiates properties to the CPU Kernel.</p> | ||
* | ||
* @extends KernelBase | ||
|
@@ -9874,7 +9846,7 @@ module.exports = function (_KernelBase) { | |
* @function | ||
* @name validateOptions | ||
* | ||
* @desc Validate options related to CPU Kernel, such as | ||
* @desc Validate options related to CPU Kernel, such as | ||
* dimensions size, and auto dimension support. | ||
* | ||
*/ | ||
|
@@ -9906,8 +9878,8 @@ module.exports = function (_KernelBase) { | |
* @function | ||
* @name build | ||
* | ||
* @desc Builds the Kernel, by generating the kernel | ||
* string using thread dimensions, and arguments | ||
* @desc Builds the Kernel, by generating the kernel | ||
* string using thread dimensions, and arguments | ||
* supplied to the kernel. | ||
* | ||
* <p>If the graphical flag is enabled, canvas is used.</p> | ||
|
@@ -9921,16 +9893,23 @@ module.exports = function (_KernelBase) { | |
this.setupParams(arguments); | ||
this.validateOptions(); | ||
var canvas = this._canvas; | ||
this._canvasCtx = canvas.getContext('2d'); | ||
if (canvas) { | ||
// if node or canvas is not found, don't die | ||
this._canvasCtx = canvas.getContext('2d'); | ||
} | ||
var threadDim = this.threadDim = utils.clone(this.output); | ||
|
||
while (threadDim.length < 3) { | ||
threadDim.push(1); | ||
} | ||
|
||
if (this.graphical) { | ||
canvas.width = threadDim[0]; | ||
canvas.height = threadDim[1]; | ||
var _canvas = this._canvas; | ||
if (!_canvas) { | ||
throw new Error('no canvas available for using graphical output'); | ||
} | ||
_canvas.width = threadDim[0]; | ||
_canvas.height = threadDim[1]; | ||
this._imageData = this._canvasCtx.createImageData(threadDim[0], threadDim[1]); | ||
this._colorData = new Uint8ClampedArray(threadDim[0] * threadDim[1] * 4); | ||
} | ||
|
@@ -9979,7 +9958,7 @@ module.exports = function (_KernelBase) { | |
* @name getKernelString | ||
* | ||
* @desc Generates kernel string for this kernel program. | ||
* | ||
* | ||
* <p>If sub-kernels are supplied, they are also factored in. | ||
* This string can be saved by calling the `toString` method | ||
* and then can be reused later.</p> | ||
|
@@ -14834,7 +14813,7 @@ module.exports = function (_KernelBase) { | |
floatOutput: this.floatOutput | ||
}, paramDim); | ||
|
||
result.push('uniform sampler2D user_' + paramName, 'ivec2 user_' + paramName + 'Size = vec2(' + paramSize[0] + ', ' + paramSize[1] + ')', 'ivec3 user_' + paramName + 'Dim = vec3(' + paramDim[0] + ', ' + paramDim[1] + ', ' + paramDim[2] + ')', 'uniform int user_' + paramName + 'BitRatio'); | ||
result.push('uniform sampler2D user_' + paramName, 'ivec2 user_' + paramName + 'Size = ivec2(' + paramSize[0] + ', ' + paramSize[1] + ')', 'ivec3 user_' + paramName + 'Dim = ivec3(' + paramDim[0] + ', ' + paramDim[1] + ', ' + paramDim[2] + ')', 'uniform int user_' + paramName + 'BitRatio'); | ||
} else if (paramType === 'Integer') { | ||
result.push('float user_' + paramName + ' = ' + param + '.0'); | ||
} else if (paramType === 'Float') { | ||
|
@@ -16249,10 +16228,6 @@ module.exports = function (_WebGLKernel) { | |
}, paramDim); | ||
|
||
result.push('uniform highp sampler2D user_' + paramName, 'highp ivec2 user_' + paramName + 'Size = ivec2(' + paramSize[0] + ', ' + paramSize[1] + ')', 'highp ivec3 user_' + paramName + 'Dim = ivec3(' + paramDim[0] + ', ' + paramDim[1] + ', ' + paramDim[2] + ')', 'uniform highp int user_' + paramName + 'BitRatio'); | ||
|
||
if (paramType === 'Array') { | ||
result.push('uniform highp int user_' + paramName + 'BitRatio'); | ||
} | ||
} else if (paramType === 'Integer') { | ||
result.push('highp float user_' + paramName + ' = ' + param + '.0'); | ||
} else if (paramType === 'Float') { | ||
|
Oops, something went wrong.