Skip to content

Commit

Permalink
Merge pull request BrainJS#148 from BrainJS/valid-options
Browse files Browse the repository at this point in the history
Handles User Object Validation BrainJS#142
  • Loading branch information
freddyC authored Feb 18, 2018
2 parents 9da27e7 + 23af33d commit 2a0093e
Show file tree
Hide file tree
Showing 7 changed files with 11,608 additions and 11,295 deletions.
21 changes: 12 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,16 @@ var output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.81, black: 0.18 }

```javascript
net.train(data, {
iterations: 20000, // the maximum times to iterate the training data
errorThresh: 0.005, // the acceptable error percentage from training data
log: false, // true to use console.log, when a function is supplied it is used
logPeriod: 10, // iterations between logging out
learningRate: 0.3, // scales with delta to effect traiing rate
momentum: 0.1, // scales with next layer's change value
callback: null, // a periodic call back that can be triggered while training
callbackPeriod: 10, // the number of iterations through the training data between callback calls
timeout: Infinity // the max number of milliseconds to train for
// Defaults values --> expected validation
iterations: 20000, // the maximum times to iterate the training data --> number greater than 0
errorThresh: 0.005, // the acceptable error percentage from training data --> number between 0 and 1
log: false, // true to use console.log, when a function is supplied it is used --> Either true or a function
logPeriod: 10, // iterations between logging out --> number greater than 0
learningRate: 0.3, // scales with delta to effect traiing rate --> number between 0 and 1
momentum: 0.1, // scales with next layer's change value --> number between 0 and 1
callback: null, // a periodic call back that can be triggered while training --> null or function
callbackPeriod: 10, // the number of iterations through the training data between callback calls --> number greater than 0
timeout: Infinity // the max number of milliseconds to train for --> number greater than 0
});
```

Expand All @@ -151,6 +152,8 @@ The momentum is similar to learning rate, expecting a value from `0` to `1` as w

Any of these training options can be passed into the constructor or passed into the `updateTrainingOptions(opts)` method and they will be saved on the network and used any time you trian. If you save your network to json, these training options are saved and restored as well (except for callback and log, callback will be forgoten and log will be restored using console.log).

There is a boolean property called `invalidTrainOptsShouldThrow` that by default is set to true. While true if you enter a training option that is outside the normal range an error will be thrown with a message about the option you sent. When set to false no error is sent but a message is still sent to `console.warn` with the information.

### Async Training
`trainAsync()` takes the same arguments as train (data and options). Instead of returning the results object from training it returns a promise that when resolved will return the training results object.

Expand Down
22,589 changes: 11,367 additions & 11,222 deletions browser.js

Large diffs are not rendered by default.

122 changes: 62 additions & 60 deletions browser.min.js

Large diffs are not rendered by default.

49 changes: 48 additions & 1 deletion dist/neural-network.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/neural-network.js.map

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion src/neural-network.js
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,31 @@ export default class NeuralNetwork {
};
}

/**
*
* @param options
* @param boolean
* @private
*/
static _validateTrainingOptions(options) {
var validations = {
iterations: (val) => { return typeof val === 'number' && val > 0; },
errorThresh: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
log: (val) => { return typeof val === 'function' || typeof val === 'boolean'; },
logPeriod: (val) => { return typeof val === 'number' && val > 0; },
learningRate: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
momentum: (val) => { return typeof val === 'number' && val > 0 && val < 1; },
callback: (val) => { return typeof val === 'function' || val === null },
callbackPeriod: (val) => { return typeof val === 'number' && val > 0; },
timeout: (val) => { return typeof val === 'number' && val > 0 }
};
Object.keys(NeuralNetwork.trainDefaults).forEach(key => {
if (validations.hasOwnProperty(key) && !validations[key](options[key])) {
throw new Error(`[${key}, ${options[key]}] is out of normal training range, your network will probably not train.`);
}
});
}

constructor(options = {}) {
Object.assign(this, this.constructor.defaults, options);
this.hiddenSizes = options.hiddenLayers;
Expand Down Expand Up @@ -293,7 +318,8 @@ export default class NeuralNetwork {
* activation: ['sigmoid', 'relu', 'leaky-relu', 'tanh']
*/
_updateTrainingOptions(opts) {
Object.keys(NeuralNetwork.trainDefaults).forEach(opt => this.trainOpts[opt] = opts[opt] || this.trainOpts[opt]);
Object.keys(NeuralNetwork.trainDefaults).forEach(opt => this.trainOpts[opt] = (opts.hasOwnProperty(opt)) ? opts[opt] : this.trainOpts[opt]);
NeuralNetwork._validateTrainingOptions(this.trainOpts);
this._setLogMethod(opts.log || this.trainOpts.log);
this.activation = opts.activation || this.activation;
}
Expand Down
92 changes: 91 additions & 1 deletion test/base/trainopts.js
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,94 @@ describe('train() and trainAsync() use the same private methods', () => {
done()
});
});
});
});

describe('training options validation', () => {
it('iterations validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ iterations: 'should be a string' }) });
assert.throws(() => { net._updateTrainingOptions({ iterations: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ iterations: false }) });
assert.throws(() => { net._updateTrainingOptions({ iterations: -1 }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ iterations: 5000 }) });
});

it('errorThresh validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ errorThresh: 'no strings'}) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: 5}) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: -1}) });
assert.throws(() => { net._updateTrainingOptions({ errorThresh: false}) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ errorThresh: 0.008}) });
});

it('log validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ log: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ log: 4 }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ log: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ log: () => {} }) });
});

it('logPeriod validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ logPeriod: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ logPeriod: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ logPeriod: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ logPeriod: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ logPeriod: 40 }) });
});

it('learningRate validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ learningRate: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: 50 }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ learningRate: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ learningRate: 0.5 }) });
});

it('momentum validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ momentum: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: 50 }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ momentum: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ momentum: 0.8 }) });
});

it('callback validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ callback: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ callback: 4 }) });
assert.throws(() => { net._updateTrainingOptions({ callback: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ callback: null }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ callback: () => {} }) });
});

it('callbackPeriod validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ callbackPeriod: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ callbackPeriod: 40 }) });
});

it('timeout validation', () => {
let net = new brain.NeuralNetwork();
assert.throws(() => { net._updateTrainingOptions({ timeout: 'no strings' }) });
assert.throws(() => { net._updateTrainingOptions({ timeout: -50 }) });
assert.throws(() => { net._updateTrainingOptions({ timeout: () => {} }) });
assert.throws(() => { net._updateTrainingOptions({ timeout: false }) });
assert.doesNotThrow(() => { net._updateTrainingOptions({ timeout: 40 }) });
});

it('should handle unsupported options', () => {
let net = new brain.NeuralNetwork();
assert.doesNotThrow(() => { net._updateTrainingOptions({ fakeProperty: 'should be handled fine' }) });
})
});

0 comments on commit 2a0093e

Please sign in to comment.