forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstatistics_test.ts
90 lines (82 loc) · 3.91 KB
/
statistics_test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import {Dataset} from './dataset';
import {TestDataset} from './dataset_test';
import {computeDatasetStatistics, scaleTo01, TabularRecord} from './statistics';
import {describeAllEnvs} from './util/test_utils';
/**
 * Tests for computeDatasetStatistics: per-column min/max/mean/variance/stddev
 * over plain numbers, number arrays, and Tensors.
 */
describeAllEnvs('computeDatasetStatistics', () => {
  it('computes numeric min and max over numbers, arrays, and Tensors',
     async () => {
       const ds = new TestDataset().skip(55) as Dataset<TabularRecord>;
       const stats = await computeDatasetStatistics(ds);
       expect(stats['number'].min).toEqual(55);
       expect(stats['number'].max).toEqual(99);
       // The TestDataset includes cubes of the indices, so the array-valued
       // and Tensor-valued columns reach 99^3 while their minimum stays 55.
       for (const column of ['numberArray', 'Tensor']) {
         expect(stats[column].min).toEqual(55);
         expect(stats[column].max).toEqual(99 * 99 * 99);
       }
     });

  it('computes numeric mean and variance over numbers, arrays, and Tensors',
     async () => {
       // According to https://github.com/tensorflow/tfjs/issues/658, unit tests
       // for mean and variance will use small numbers.
       const ds = new TestDataset().take(10) as Dataset<TabularRecord>;
       const stats = await computeDatasetStatistics(ds);
       expect(stats['number'].min).toEqual(0);
       expect(stats['number'].max).toEqual(9);
       expect(stats['number'].mean).toEqual(4.5);
       expect(stats['number'].variance).toEqual(8.25);
       expect(stats['number'].stddev).toEqual(2.8722813232690143);

       // The TestDataset includes cubes of the indices; the 'numberArray' and
       // 'Tensor' columns are expected to produce the same statistics, which
       // are compared approximately (variance with a looser 0.1 tolerance).
       const expectedCubeMean = 78.50000066757202;
       const expectedCubeVariance = 26971.849851671857;
       const expectedCubeStddev = 164.2310867396056;
       for (const column of ['numberArray', 'Tensor']) {
         expect(stats[column].min).toEqual(0);
         expect(stats[column].max).toEqual(729);
         tf.test_util.expectNumbersClose(stats[column].mean, expectedCubeMean);
         tf.test_util.expectNumbersClose(
             stats[column].variance, expectedCubeVariance, 0.1);
         tf.test_util.expectNumbersClose(
             stats[column].stddev, expectedCubeStddev);
       }
     });

  it('variance should be NaN if there is only one element', async () => {
    const ds = new TestDataset().take(1) as Dataset<TabularRecord>;
    const stats = await computeDatasetStatistics(ds);
    expect(stats['number'].min).toEqual(0);
    expect(stats['number'].max).toEqual(0);
    expect(stats['number'].mean).toEqual(0);
    // With a single sample the variance (and hence stddev) is undefined.
    expect(stats['number'].variance).toBeNaN();
    expect(stats['number'].stddev).toBeNaN();
  });
});
/**
 * Tests for scaleTo01: a min/max-derived scaling function should map the
 * observed range of a column onto [0, 1].
 */
describeAllEnvs('scaleTo01', () => {
  it('scales numeric data to the [0, 1] interval', async () => {
    // After skipping the first 55 records the Tensor column spans
    // [55, 99^3] (the TestDataset includes cubes of the indices).
    const lowerBound = 55;
    const upperBound = 99 * 99 * 99;
    const source = new TestDataset().skip(55) as Dataset<TabularRecord>;
    const rescale = scaleTo01(lowerBound, upperBound);

    const rescaled = source.map(record => {
      const scaledTensor = rescale(record['Tensor']);
      return {'Tensor': scaledTensor};
    });

    const stats = await computeDatasetStatistics(rescaled);
    const tensorStats = stats['Tensor'];
    expect(tensorStats.min).toBeCloseTo(0);
    expect(tensorStats.max).toBeCloseTo(1);
  });
});