Skip to content

Commit

Permalink
fix: Minor fix for future compatibility of using CrossValidate with…
Browse files Browse the repository at this point in the history
… other neural networks

Bump version number, and as well bump bower version number
  • Loading branch information
robertleeplummerjr committed Sep 22, 2018
1 parent b2bec1e commit 5797b87
Show file tree
Hide file tree
Showing 12 changed files with 49 additions and 53 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ With multiple networks you can train in parallel like this:
### Cross Validation
[Cross Validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) can provide a less fragile way of training on larger data sets. The brain.js api provides Cross Validation in this example:
```js
const crossValidate = new CrossValidate(brain.NeuralNetwork);
const stats = crossValidate.train(data, networkOptions, trainingOptions, k); //note k (or KFolds) is optional
const net = crossValidate.toNetwork();
const crossValidate = new CrossValidate(brain.NeuralNetwork, networkOptions);
const stats = crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
const net = crossValidate.toNeuralNetwork();


// optionally later
Expand Down
2 changes: 1 addition & 1 deletion bower.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@
"node_modules",
"test"
],
"version": "1.1.1"
"version": "1.3.1"
}
18 changes: 9 additions & 9 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: Heather Arthur <[email protected]>
* homepage: https://github.com/brainjs/brain.js#readme
* version: 1.3.0
* version: 1.3.1
*
* acorn:
* license: MIT (http://opensource.org/licenses/MIT)
Expand Down Expand Up @@ -127,17 +127,18 @@ var CrossValidate = function () {
/**
*
* @param {NeuralNetwork|constructor} Classifier
* @param {object} [options]
*/
function CrossValidate(Classifier) {
function CrossValidate(Classifier, options) {
_classCallCheck(this, CrossValidate);

this.Classifier = Classifier;
this.options = options;
this.json = null;
}

/**
*
* @param {object} options
* @param {object} trainOpts
* @param {object} trainSet
* @param {object} testSet
Expand All @@ -147,8 +148,8 @@ var CrossValidate = function () {

_createClass(CrossValidate, [{
key: "testPartition",
value: function testPartition(options, trainOpts, trainSet, testSet) {
var classifier = new this.Classifier(options);
value: function testPartition(trainOpts, trainSet, testSet) {
var classifier = new this.Classifier(this.options);
var beginTrain = Date.now();
var trainingStats = classifier.train(trainSet, trainOpts);
var beginTest = Date.now();
Expand Down Expand Up @@ -188,7 +189,6 @@ var CrossValidate = function () {
/**
*
* @param {object} data
* @param {object} options
* @param {object} trainOpts
* @param {number} [k]
* @returns {
Expand All @@ -214,7 +214,7 @@ var CrossValidate = function () {

}, {
key: "train",
value: function train(data, options, trainOpts, k) {
value: function train(data, trainOpts, k) {
k = k || 4;
var size = data.length / k;

Expand Down Expand Up @@ -291,8 +291,8 @@ var CrossValidate = function () {
};
}
}, {
key: "toNetwork",
value: function toNetwork() {
key: "toNeuralNetwork",
value: function toNeuralNetwork() {
return this.fromJSON(this.json);
}
}, {
Expand Down
8 changes: 4 additions & 4 deletions browser.min.js

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions dist/cross-validate.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/cross-validate.js.map

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions examples-typescript/cross-validate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import * as assert from 'assert';
import * as brain from '../index';

const trainingData = [
// xor
// xor data, repeating to simulate that we have a lot of data
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
Expand All @@ -29,17 +29,17 @@ const trainingData = [

const netOptions = {
hiddenSizes: [3]
} as brain.INeuralNetworkDefaultOptions;
} as brain.INeuralNetworkOptions;

const trainingOptions = {
iterations: 20000,
log: details => console.log(details)
} as brain.INeuralNetworkTrainingOptions;

const crossValidate = new brain.CrossValidate(brain.NeuralNetwork);
const stats = crossValidate.train(trainingData, netOptions, trainingOptions);
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, netOptions);
const stats = crossValidate.train(trainingData, trainingOptions);
console.log(stats);
const net = crossValidate.toNetwork();
const net = crossValidate.toNeuralNetwork();
const result01 = net.run([0, 1]);
const result00 = net.run([0, 0]);
const result11 = net.run([1, 1]);
Expand Down
8 changes: 4 additions & 4 deletions examples/cross-validate.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ const assert = require('assert');
const brain = require('../dist/index').default;

const trainingData = [
// xor
// xor data, repeating to simulate that we have a lot of data
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
Expand Down Expand Up @@ -36,10 +36,10 @@ const trainingOptions = {
log: details => console.log(details)
};

const crossValidate = new brain.CrossValidate(brain.NeuralNetwork);
const stats = crossValidate.train(trainingData, netOptions, trainingOptions);
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, netOptions);
const stats = crossValidate.train(trainingData, trainingOptions);
console.log(stats);
const net = crossValidate.toNetwork();
const net = crossValidate.toNeuralNetwork();
const result01 = net.run([0, 1]);
const result00 = net.run([0, 0]);
const result11 = net.run([1, 1]);
Expand Down
14 changes: 6 additions & 8 deletions index.d.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* NeuralNetwork section */
export interface INeuralNetworkDefaultOptions {
export interface INeuralNetworkOptions {
binaryThresh?: number;
hiddenLayers?: number[];
activation?: NeuralNetworkActivation;
Expand Down Expand Up @@ -45,7 +45,7 @@ export interface INeuralNetworkTrainingData {
export type NeuralNetworkTrainingValue = number[];

export class NeuralNetwork {
public constructor(options?: INeuralNetworkDefaultOptions);
public constructor(options?: INeuralNetworkOptions);
public train(data: INeuralNetworkTrainingData[], options?: INeuralNetworkTrainingOptions): INeuralNetworkState;
public train<T>(data: T, options?: INeuralNetworkTrainingOptions): INeuralNetworkState;
public trainAsync(data: INeuralNetworkTrainingData, options?: INeuralNetworkTrainingOptions): Promise<INeuralNetworkState>;
Expand Down Expand Up @@ -83,22 +83,20 @@ export interface ICrossValidationTestPartitionResults {
}

export class CrossValidate {
public constructor(Classifier: typeof NeuralNetwork);
public constructor(Classifier: typeof NeuralNetwork, options?: INeuralNetworkOptions);
public fromJSON(json: ICrossValidateJSON): NeuralNetwork;
public toJSON(): ICrossValidateJSON;
public train(
data: INeuralNetworkTrainingData[],
networkOptions: INeuralNetworkDefaultOptions,
trainingOptions: INeuralNetworkTrainingOptions,
k?: number): ICrossValidateStats;
public train<T>(
data: T,
networkOptions: INeuralNetworkDefaultOptions,
trainingOptions: INeuralNetworkTrainingOptions,
k?: number): ICrossValidateStats;
public testPartition(): ICrossValidationTestPartitionResults;
public toNetwork(): NeuralNetwork;
public toNetwork<T>(): T;
public toNeuralNetwork(): NeuralNetwork;
public toNeuralNetwork<T>(): T;
}

/* TrainStream section */
Expand All @@ -118,7 +116,7 @@ export class TrainStream {
/* recurrent section */
export type RNNTrainingValue = string;
export type RNNTimeStepTrainingValue = NeuralNetworkTrainingValue | number | number[] | number[][];
export interface IRNNDefaultOptions extends INeuralNetworkDefaultOptions {
export interface IRNNDefaultOptions extends INeuralNetworkOptions {
inputSize?: number;
outputSize?: number;
}
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "brain.js",
"description": "Neural network library",
"version": "1.3.0",
"version": "1.3.1",
"author": "Heather Arthur <[email protected]>",
"repository": {
"type": "git",
Expand Down
14 changes: 7 additions & 7 deletions src/cross-validate.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@ export default class CrossValidate {
/**
*
* @param {NeuralNetwork|constructor} Classifier
* @param {object} [options]
*/
constructor(Classifier) {
constructor(Classifier, options) {
this.Classifier = Classifier;
this.options = options;
this.json = null;
}

/**
*
* @param {object} options
* @param {object} trainOpts
* @param {object} trainSet
* @param {object} testSet
* @returns {void|*}
*/
testPartition(options, trainOpts, trainSet, testSet) {
let classifier = new this.Classifier(options);
testPartition(trainOpts, trainSet, testSet) {
let classifier = new this.Classifier(this.options);
let beginTrain = Date.now();
let trainingStats = classifier.train(trainSet, trainOpts);
let beginTest = Date.now();
Expand Down Expand Up @@ -55,7 +56,6 @@ export default class CrossValidate {
/**
*
* @param {object} data
* @param {object} options
* @param {object} trainOpts
* @param {number} [k]
* @returns {
Expand All @@ -78,7 +78,7 @@ export default class CrossValidate {
* }
* }
*/
train(data, options, trainOpts, k) {
train(data, trainOpts, k) {
k = k || 4;
let size = data.length / k;

Expand Down Expand Up @@ -156,7 +156,7 @@ export default class CrossValidate {
};
}

toNetwork() {
toNeuralNetwork() {
return this.fromJSON(this.json);
}

Expand Down
2 changes: 0 additions & 2 deletions test/base/json.js
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,6 @@ describe('default net json', () => {
});

it('training options timeout', () => {
console.log(originalNet.trainOpts.timeout);
console.log(serializedNet.trainOpts.timeout);
assert.equal(originalNet.trainOpts.timeout, serializedNet.trainOpts.timeout, `originalNet.trainOpts are: ${originalNet.trainOpts.timeout} serializedNet should be the same but are: ${serializedNet.trainOpts.timeout}`);
});
});
Expand Down

0 comments on commit 5797b87

Please sign in to comment.