Skip to content

Commit

Permalink
Require shape in tensor2d/3d/4d(flatValues, shape) (tensorflow#969)
Browse files Browse the repository at this point in the history
**Bug**
`tensor2d([1, 2, 3, 4])` returns Tensor of rank 1 with shape `[4]`. Likewise for `tensor3d` and `tensor4d`

**Solution**
`tensor2d(flatValues)` throws an error requiring shape to be provided explicitly by the user.
  • Loading branch information
dsmilkov authored Apr 18, 2018
1 parent 8a01889 commit 87ae277
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ export class ArrayOps {
static tensor1d(values: TensorLike1D, dtype: DataType = 'float32'): Tensor1D {
const inferredShape = util.inferShape(values);
if (inferredShape.length !== 1) {
throw new Error(
'Error creating a new Tensor1D: values must be a flat/TypedArray');
throw new Error('tensor1d() requires values to be a flat/TypedArray');
}
return ArrayOps.tensor(values, inferredShape as [number], dtype);
}
Expand Down Expand Up @@ -151,8 +150,12 @@ export class ArrayOps {
const inferredShape = util.inferShape(values);
if (inferredShape.length !== 2 && inferredShape.length !== 1) {
throw new Error(
'Error creating a new Tensor2D: values must be number[][] ' +
'or flat/TypedArray');
'tensor2d() requires values to be number[][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error(
'tensor2d() requires shape to be provided when `values` ' +
'are a flat/TypedArray');
}
shape = shape || inferredShape as [number, number];
return ArrayOps.tensor(values, shape, dtype);
Expand Down Expand Up @@ -186,8 +189,12 @@ export class ArrayOps {
const inferredShape = util.inferShape(values);
if (inferredShape.length !== 3 && inferredShape.length !== 1) {
throw new Error(
'Error creating a new Tensor3D: values must be number[][][]' +
'or flat/TypedArray');
'tensor3d() requires values to be number[][][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error(
'tensor3d() requires shape to be provided when `values` ' +
'are a flat array');
}
shape = shape || inferredShape as [number, number, number];
return ArrayOps.tensor(values, shape, dtype);
Expand Down Expand Up @@ -221,8 +228,12 @@ export class ArrayOps {
const inferredShape = util.inferShape(values);
if (inferredShape.length !== 4 && inferredShape.length !== 1) {
throw new Error(
'Error creating a new Tensor4D: values must be number[][][][]' +
'or flat/TypedArray');
'tensor4d() requires values to be number[][][][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error(
'tensor4d() requires shape to be provided when `values` ' +
'are a flat array');
}
shape = shape || inferredShape as [number, number, number, number];
return ArrayOps.tensor(values, shape, dtype);
Expand Down
12 changes: 12 additions & 0 deletions src/tensor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(() => tf.tensor2d([[1, 2, 3], [4, 5, 6]], [3, 2])).toThrowError();
});

it('tf.tensor2d() from number[], but no shape throws error', () => {
expect(() => tf.tensor2d([1, 2, 3, 4])).toThrowError();
});

it('tensor3d() from number[][][]', () => {
const a = tf.tensor3d([[[1], [2], [3]], [[4], [5], [6]]], [2, 3, 1]);
expectArraysClose(a, [1, 2, 3, 4, 5, 6]);
Expand All @@ -296,6 +300,10 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(() => tf.tensor3d(values, [3, 2, 1])).toThrowError();
});

it('tf.tensor3d() from number[], but no shape throws error', () => {
expect(() => tf.tensor3d([1, 2, 3, 4])).toThrowError();
});

it('tensor4d() from number[][][][]', () => {
const a = tf.tensor4d([[[[1]], [[2]]], [[[4]], [[5]]]], [2, 2, 1, 1]);
expectArraysClose(a, [1, 2, 4, 5]);
Expand All @@ -309,6 +317,10 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(f).toThrowError();
});

it('tf.tensor4d() from number[], but no shape throws error', () => {
expect(() => tf.tensor4d([1, 2, 3, 4])).toThrowError();
});

it('default dtype', () => {
const a = tf.scalar(3);
expect(a.dtype).toBe('float32');
Expand Down

0 comments on commit 87ae277

Please sign in to comment.