Skip to content

Commit

Permalink
[PyTorch] Add & use inferExpandGeometry_dimvector (pytorch#55316)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#55316

No need for heap allocations in the common case here.
ghstack-source-id: 126170054

Test Plan: Existing CI

Reviewed By: hlu1

Differential Revision: D27571942

fbshipit-source-id: 11fbf077c583c80ea63e024d2b9e1599785fff71
  • Loading branch information
swolchok authored and facebook-github-bot committed Apr 10, 2021
1 parent 151869a commit 548765d
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 15 deletions.
33 changes: 25 additions & 8 deletions aten/src/ATen/ExpandUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,22 @@ DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
return infer_size_impl<DimVector>(a, b);
}

std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
template <typename Container>
std::tuple<Container, Container> inferExpandGeometryImpl(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides,
IntArrayRef sizes) {
int64_t ndim = sizes.size();
int64_t tensor_dim = tensor_sizes.size();

if (tensor_dim == 0) {
std::vector<int64_t> expandedStrides(ndim, 0);
return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
sizes.vec(), expandedStrides);
return std::make_tuple(
Container(sizes.begin(), sizes.end()), Container(ndim, 0));
}
std::vector<int64_t> expandedSizes(ndim);
std::vector<int64_t> expandedStrides(ndim);

std::tuple<Container, Container> result{Container(ndim), Container(ndim)};
auto& expandedSizes = std::get<0>(result);
auto& expandedStrides = std::get<1>(result);

// create a new geometry for the tensors
for (int64_t i = ndim - 1; i >= 0; --i) {
Expand Down Expand Up @@ -94,8 +96,23 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
expandedSizes[i] = size;
expandedStrides[i] = stride;
}
return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
expandedSizes, expandedStrides);
return result;
}

std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides,
IntArrayRef sizes) {
return inferExpandGeometryImpl<std::vector<int64_t>>(
tensor_sizes, tensor_strides, sizes);
}

std::tuple<DimVector, DimVector> inferExpandGeometry_dimvector(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides,
IntArrayRef sizes) {
return inferExpandGeometryImpl<DimVector>(
tensor_sizes, tensor_strides, sizes);
}


Expand Down
5 changes: 5 additions & 0 deletions aten/src/ATen/ExpandUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ inferExpandGeometry(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides,
IntArrayRef sizes);
TORCH_API std::tuple<DimVector, DimVector>
inferExpandGeometry_dimvector(
IntArrayRef tensor_sizes,
IntArrayRef tensor_strides,
IntArrayRef sizes);

TORCH_API std::vector<int64_t> infer_dense_strides(
IntArrayRef tensor_sizes,
Expand Down
7 changes: 3 additions & 4 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,11 +734,10 @@ Tensor expand(const Tensor& self, IntArrayRef size, bool implicit) {
"must be greater or equal to the number of dimensions in the tensor (",
self.dim(), ")");

std::vector<int64_t> expandedSizes;
std::vector<int64_t> expandedStrides;
std::tie(expandedSizes, expandedStrides) = inferExpandGeometry(self.sizes(), self.strides(), size);
auto expandedSizesAndStrides = inferExpandGeometry_dimvector(self.sizes(), self.strides(), size);

auto result = self.as_strided(expandedSizes, expandedStrides);
auto result = self.as_strided(
std::get<0>(expandedSizesAndStrides), std::get<1>(expandedSizesAndStrides));
namedinference::propagate_names_for_expand(result, self);
return result;
}
Expand Down
6 changes: 3 additions & 3 deletions torch/csrc/jit/passes/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2110,12 +2110,12 @@ class ShapePropagator {
"aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
/*const_inputs=*/attr::size)) {
auto tp = tensor_types.at(0);
std::vector<int64_t> sizes, strides;
std::tie(sizes, strides) = at::inferExpandGeometry(
auto sizesAndStrides = at::inferExpandGeometry_dimvector(
tp->sizes().concrete_sizes().value(),
tp->strides().concrete_sizes().value(),
node->get<c10::List<int64_t>>(attr::size).value().vec());
node->output()->setType(tp->withSizesStrides(sizes, strides));
node->output()->setType(tp->withSizesStrides(
std::get<0>(sizesAndStrides), std::get<1>(sizesAndStrides)));
return true;
} else if (
node->matches(
Expand Down

0 comments on commit 548765d

Please sign in to comment.