
SymInt support for computeStride (pytorch#84905)
Pull Request resolved: pytorch#84905
Approved by: https://github.com/ezyang
Krovatkin authored and pytorchmergebot committed Sep 13, 2022
1 parent 8b8141e commit 3f047b2
Showing 2 changed files with 24 additions and 11 deletions.
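
The diff makes computeStride_impl generic over its arithmetic element type: a new Numel template parameter replaces hard-coded int64_t, so the same stride computation compiles both for concrete int64_t shapes and for symbolic c10::SymInt shapes. A minimal sketch of that pattern follows (not code from this PR; the helper name product_of is made up for illustration):

// Sketch only: templatize the arithmetic element type, as the diff below does
// with the Numel parameter of computeStride_impl.
#include <c10/core/SymInt.h>

#include <cstdint>
#include <vector>

template <typename Numel, typename Vec>
Numel product_of(const Vec& sizes) {
  Numel out = 1;                  // c10::SymInt is implicitly constructible from int64_t
  for (const auto& s : sizes) {
    out = out * s;                // operator* works for int64_t and for c10::SymInt
  }
  return out;
}

// product_of<int64_t>(std::vector<int64_t>{2, 3, 4})      -> 24
// product_of<c10::SymInt>(std::vector<c10::SymInt>{...})  -> a symbolic expression

In the diff below, computeStride_impl is instantiated the same way: Numel = int64_t for the IntArrayRef/DimVector overloads, and Numel = c10::SymInt for the new SymIntArrayRef overload.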
30 changes: 19 additions & 11 deletions aten/src/ATen/TensorUtils.cpp
@@ -310,12 +310,12 @@ std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
 // templatized for DimVector and IntArrayRef use cases,
 // see overloads of computeStride() below.
 //
-template <typename ResultVec, typename NewShapeVec>
+template <typename ResultVec, typename NewShapeVec, typename Numel>
 inline c10::optional<ResultVec> computeStride_impl(
-    IntArrayRef oldshape,
-    IntArrayRef oldstride,
+    const NewShapeVec& oldshape,
+    const NewShapeVec& oldstride,
     const NewShapeVec& newshape,
-    ResultVec toResult(const IntArrayRef&)
+    ResultVec toResult(const NewShapeVec&)
 ) {
   if (oldshape.empty()) {
     return ResultVec(newshape.size(), 1);
@@ -326,7 +326,7 @@ inline c10::optional<ResultVec> computeStride_impl(
   // we use the stride as if it were computed via resize.
   // This could perhaps be combined with the below code, but the complexity
   // didn't seem worth it.
-  const int64_t numel = c10::multiply_integers(oldshape);
+  const Numel numel = c10::multiply_integers(oldshape);
   if (numel == 0 && oldshape.equals(newshape)) {
     return toResult(oldstride);
   }
@@ -338,18 +338,18 @@ inline c10::optional<ResultVec> computeStride_impl(
         newstride[view_d] = 1;
       } else {
         newstride[view_d] =
-          std::max<int64_t>(newshape[view_d+1], 1) * newstride[view_d+1];
+          std::max<Numel>(newshape[view_d+1], Numel(1)) * newstride[view_d+1];
       }
     }
     return newstride;
   }
 
   int64_t view_d = (int64_t)newshape.size() - 1;
   // stride for each subspace in the chunk
-  int64_t chunk_base_stride = oldstride.back();
+  Numel chunk_base_stride = oldstride.back();
   // numel in current chunk
-  int64_t tensor_numel = 1;
-  int64_t view_numel = 1;
+  Numel tensor_numel = 1;
+  Numel view_numel = 1;
   for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
     tensor_numel *= oldshape[tensor_d];
     // if end of tensor size chunk, check view
@@ -383,15 +383,23 @@ c10::optional<std::vector<int64_t>> computeStride(
     IntArrayRef oldstride,
     IntArrayRef newshape) {
   auto toResult = [](const IntArrayRef& a) { return a.vec(); };
-  return computeStride_impl<std::vector<int64_t>, IntArrayRef>(oldshape, oldstride, newshape, toResult);
+  return computeStride_impl<std::vector<int64_t>, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
 }
 
+c10::optional<SymDimVector> computeStride(
+    c10::SymIntArrayRef oldshape,
+    c10::SymIntArrayRef oldstride,
+    c10::SymIntArrayRef newshape) {
+  auto toResult = [](const SymIntArrayRef& a) { return SymDimVector(a); };
+  return computeStride_impl<SymDimVector, c10::SymIntArrayRef, c10::SymInt>(oldshape, oldstride, newshape, toResult);
+}
+
 c10::optional<DimVector> computeStride(
     IntArrayRef oldshape,
     IntArrayRef oldstride,
     const DimVector& newshape) {
   auto toResult = [](const IntArrayRef& a) { return DimVector(a); };
-  return computeStride_impl<DimVector, DimVector>(oldshape, oldstride, newshape, toResult);
+  return computeStride_impl<DimVector, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
 }
 
 } // namespace detail
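
For context, a hedged usage sketch of the new overload (not part of the PR; the wrapper function and variable names are illustrative): it asks whether a reshape of a tensor with symbolic sizes can be expressed as a view of the same storage, and gets back strides when it can.

#include <ATen/TensorUtils.h>

bool view_is_possible(c10::SymIntArrayRef old_sizes,
                      c10::SymIntArrayRef old_strides,
                      c10::SymIntArrayRef new_sizes) {
  // Returns c10::nullopt when no stride assignment can express the reshape
  // as a view of the original storage (a copy would be needed).
  auto maybe_strides = at::detail::computeStride(old_sizes, old_strides, new_sizes);
  return maybe_strides.has_value();
}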
5 changes: 5 additions & 0 deletions aten/src/ATen/TensorUtils.h
@@ -157,6 +157,11 @@ TORCH_API c10::optional<std::vector<int64_t>> computeStride(
     IntArrayRef oldstride,
     IntArrayRef newshape);
 
+TORCH_API c10::optional<SymDimVector> computeStride(
+    c10::SymIntArrayRef oldshape,
+    c10::SymIntArrayRef oldstride,
+    c10::SymIntArrayRef newshape);
+
 TORCH_API c10::optional<DimVector> computeStride(
     IntArrayRef oldshape,
     IntArrayRef oldstride,