forked from tensorflow/tfjs-core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lrn.ts
69 lines (66 loc) · 2.57 KB
/
lrn.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {doc} from '../doc';
import {ENV} from '../environment';
import {Tensor3D, Tensor4D} from '../tensor';
import * as util from '../util';
import {operation} from './operation';
export class LRNOps {
  /**
   * Normalizes the activation of a local neighborhood across or within
   * channels.
   *
   * @param x The input tensor, of rank 3 or 4. The 4-D input tensor is
   *     treated as a 3-D array of 1D vectors (along the last dimension), and
   *     each vector is normalized independently. A rank-3 input is handled by
   *     prepending a leading dimension of size 1, and the result is returned
   *     with rank 3 again.
   * @param radius The number of adjacent channels or spatial locations of the
   *     1D normalization window. Must be an integer. In Tensorflow this param
   *     is called 'depth_radius' because only 'acrossChannels' mode is
   *     supported.
   * @param bias A constant bias term for the basis.
   * @param alpha A scale factor, usually positive.
   * @param beta An exponent.
   * @returns The normalized tensor, with the same rank as `x`.
   */
  @doc({heading: 'Operations', subheading: 'Normalization'})
  @operation
  static localResponseNormalization<T extends Tensor3D|Tensor4D>(
      x: T, radius = 5, bias = 1, alpha = 1, beta = 0.5): T {
    util.assertArgumentsAreTensors({x}, 'localResponseNormalization');
    // Messages are built by concatenation so that no newline or source
    // indentation leaks into the error text.
    util.assert(
        x.rank === 4 || x.rank === 3,
        `Error in localResponseNormalization: x must be rank 3 or 4 but got ` +
            `rank ${x.rank}.`);
    util.assert(
        util.isInt(radius),
        `Error in localResponseNormalization: radius must be an integer ` +
            `but got radius ${radius}.`);

    // The backend entry point used here is the 4-D kernel
    // (`localResponseNormalization4D`), so a rank-3 input is lifted to a
    // leading dimension of 1 and squeezed back after the kernel runs.
    const reshapedTo4D = x.rank === 3;
    const x4D: Tensor4D = reshapedTo4D ?
        x.as4D(1, x.shape[0], x.shape[1], x.shape[2]) :
        x as Tensor4D;

    const res = ENV.engine.runKernel(
        backend => backend.localResponseNormalization4D(
            x4D, radius, bias, alpha, beta),
        {x4D});

    // Restore the caller's original rank.
    return reshapedTo4D ?
        res.as3D(res.shape[1], res.shape[2], res.shape[3]) as T :
        res as T;
  }
}