forked from tensorflow/tfjs-core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
slice.ts
162 lines (154 loc) · 5.68 KB
/
slice.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {doc} from '../doc';
import {ENV} from '../environment';
import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
import {Rank} from '../types';
import * as util from '../util';
import {operation} from './operation';
import * as slice_util from './slice_util';
export class SliceOps {
  /**
   * Extracts a 1D slice from a 1D array starting at coordinate `begin` and
   * of length `size`. See `slice` for details.
   *
   * Asserts that `x` is rank 1 before delegating to `slice`.
   *
   * @param x The rank-1 tensor to slice from.
   * @param begin The index along the single axis where the slice starts.
   * @param size The number of elements in the slice.
   */
  static slice1d(x: Tensor1D, begin: number, size: number): Tensor1D {
    util.assert(
        x.rank === 1,
        `slice1d expects a rank-1 tensor, but got a rank-${x.rank} tensor`);
    return SliceOps.slice(x, [begin], [size]);
  }

  /**
   * Extracts a 2D slice from a 2D array starting at coordinates `begin` and
   * of size `size`. See `slice` for details.
   *
   * Asserts that `x` is rank 2 before delegating to `slice`.
   *
   * @param x The rank-2 tensor to slice from.
   * @param begin The per-axis coordinates where the slice starts.
   * @param size The per-axis size of the slice.
   */
  static slice2d(x: Tensor2D, begin: [number, number], size: [number, number]):
      Tensor2D {
    util.assert(
        x.rank === 2,
        `slice2d expects a rank-2 tensor, but got a rank-${x.rank} tensor`);
    return SliceOps.slice(x, begin, size);
  }

  /**
   * Extracts a 3D slice from a 3D array starting at coordinates `begin` and
   * of size `size`. See `slice` for details.
   *
   * Asserts that `x` is rank 3 before delegating to `slice`.
   *
   * @param x The rank-3 tensor to slice from.
   * @param begin The per-axis coordinates where the slice starts.
   * @param size The per-axis size of the slice.
   */
  static slice3d(x: Tensor3D, begin: [number, number, number], size: [
    number, number, number
  ]): Tensor3D {
    util.assert(
        x.rank === 3,
        `slice3d expects a rank-3 tensor, but got a rank-${x.rank} tensor`);
    return SliceOps.slice(x, begin, size);
  }

  /**
   * Extracts a 4D slice from a 4D array starting at coordinates `begin` and
   * of size `size`. See `slice` for details.
   *
   * Asserts that `x` is rank 4 before delegating to `slice`.
   *
   * @param x The rank-4 tensor to slice from.
   * @param begin The per-axis coordinates where the slice starts.
   * @param size The per-axis size of the slice.
   */
  static slice4d(x: Tensor4D, begin: [number, number, number, number], size: [
    number, number, number, number
  ]): Tensor4D {
    util.assert(
        x.rank === 4,
        `slice4d expects a rank-4 tensor, but got a rank-${x.rank} tensor`);
    return SliceOps.slice(x, begin, size);
  }

  /**
   * Extracts a slice from a `Tensor` starting at coordinates `begin`
   * and is of size `size`.
   *
   * Also available are stricter rank-specific methods with the same signature
   * as this method that assert that `x` is of the given rank:
   *   - `tf.slice1d`
   *   - `tf.slice2d`
   *   - `tf.slice3d`
   *   - `tf.slice4d`
   *
   * ```js
   * const x = tf.tensor1d([1, 2, 3, 4]);
   *
   * x.slice([1], [2]).print();
   * ```
   *
   * ```js
   * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
   *
   * x.slice([1, 0], [1, 2]).print();
   * ```
   * @param x The input `Tensor` to slice from.
   * @param begin The coordinates to start the slice from. The length can be
   *     less than the rank of x - the rest of the axes will have implicit 0 as
   *     start. Can also be a single number, in which case it specifies the
   *     first axis.
   * @param size The size of the slice. The length can be less than the rank of
   *     x - the rest of the axes will have implicit -1. A value of -1 requests
   *     the rest of the dimensions in the axis. Can also be a single number,
   *     in which case it specifies the size of the first axis.
   */
  @doc({heading: 'Tensors', subheading: 'Slicing and Joining'})
  @operation
  static slice<R extends Rank, T extends Tensor<R>>(
      x: T, begin: number|number[], size?: number|number[]): T {
    util.assertArgumentsAreTensors({x}, 'slice');

    if (x.rank === 0) {
      throw new Error('Slicing scalar is not possible');
    }

    // The following logic allows for more ergonomic calls.
    // Normalize `begin` to one entry per axis; unspecified axes start at 0.
    let begin_: number[];
    if (typeof begin === 'number') {
      begin_ = [begin, ...new Array(x.rank - 1).fill(0)];
    } else if (begin.length < x.rank) {
      begin_ = begin.concat(new Array(x.rank - begin.length).fill(0));
    } else {
      begin_ = begin;
    }

    // Normalize `size` to one entry per axis; unspecified axes get -1, which
    // is resolved below to "everything from `begin` to the end of the axis".
    let size_: number[];
    if (size == null) {
      size_ = new Array(x.rank).fill(-1);
    } else if (typeof size === 'number') {
      size_ = [size, ...new Array(x.rank - 1).fill(-1)];
    } else if (size.length < x.rank) {
      size_ = size.concat(new Array(x.rank - size.length).fill(-1));
    } else {
      size_ = size;
    }

    // Replace each -1 with the remaining extent of its axis. Any other
    // negative value is rejected.
    size_ = size_.map((d, i) => {
      if (d >= 0) {
        return d;
      } else {
        util.assert(d === -1, 'Bad value in size');
        return x.shape[i] - begin_[i];
      }
    });

    slice_util.assertParamsValid(x, begin_, size_);

    // Captured for the gradient, which needs the original extents.
    const inputShape = x.shape;
    const grad = (dy: T) => {
      // Create an Nx2 padding where the first column represents how many
      // zeros are prepended (at start) for each dimension, and the second
      // column indicates how many zeros are appended (at end).

      // The number of zeros to append is the shape of the input
      // elementwise-subtracted by both the begin vector and sizes vector.
      const paddings: Array<[number, number]> = [];
      for (let i = 0; i < dy.rank; i++) {
        paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
      }
      return {x: () => dy.pad(paddings)};
    };

    return ENV.engine.runKernel(
               backend => backend.slice(x, begin_, size_), {x}, grad) as T;
  }
}