reland "Do not use unsafe restriding for subclasses (pytorch#87610)" (p…
Browse files Browse the repository at this point in the history
…ytorch#88343)

This reverts commit 5b75b19.
Pull Request resolved: pytorch#88343
Approved by: https://github.com/ezyang
bdhirsh authored and pytorchmergebot committed Nov 14, 2022
1 parent 9943d46 commit ec4eada
Showing 5 changed files with 52 additions and 44 deletions.
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -928,6 +928,11 @@ Tensor index_copy_decomp(
return at::scatter(self, dim, index_, source);
}

// Note [Fix vmap slice_scatter]
// This registers a decomposition for `slice_scatter` that calls into `slice.src`.
// *_scatter operators have some special semantics, though, that we can't easily
// express through a decomposition: slice_scatter's output needs to have the same
// size, strides, and storage_offset as the input.
Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src,
int64_t dim, c10::optional<int64_t> start,
c10::optional<int64_t> end, int64_t step)
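To make the invariant in Note [Fix vmap slice_scatter] concrete, here is a minimal Python sketch (not part of this change) checking that slice_scatter's output carries the metadata of `self`, not of `src`:

import torch

# Sketch of the invariant from Note [Fix vmap slice_scatter]: the output of
# slice_scatter has the same size, strides, and storage_offset as `self`.
self_t = torch.zeros(4, 2)
src = torch.ones(2, 2)
out = torch.slice_scatter(self_t, src, dim=0, start=0, end=2)

assert out.size() == self_t.size()
assert out.stride() == self_t.stride()
assert out.storage_offset() == self_t.storage_offset()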
3 changes: 2 additions & 1 deletion aten/src/ATen/native/TensorShape.cpp
@@ -3,6 +3,7 @@
#include <ATen/core/DimVector.h>
#include <ATen/core/functional.h>
#include <ATen/core/IListRef.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
@@ -1573,7 +1574,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
//
// We need to do the checks here instead of in `native_functions.yaml`
// to preserve backwards compatibility.
if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) {
if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) {
return self._reshape_alias_symint(shape, stride.value());
} else {
return self.view_symint(shape);
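As a hedged illustration of what this check changes (not code from the commit), the sketch below defines a minimal __torch_dispatch__ wrapper subclass and logs which aten op reshape() routes to; with the subclass-aware check above, it should observe aten.view.default rather than the unsafe aten._reshape_alias.default restriding. The name LoggingTensor is hypothetical.

import torch
from torch.utils._pytree import tree_map

# Hypothetical wrapper subclass, used only to observe which op reshape() hits.
class LoggingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.size(), dtype=elem.dtype)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(func)  # expected: aten.view.default rather than aten._reshape_alias.default
        unwrap = lambda t: t.elem if isinstance(t, LoggingTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(lambda t: LoggingTensor(t) if isinstance(t, torch.Tensor) else t, out)

x = LoggingTensor(torch.randn(4, 2))
y = x.reshape(8)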
2 changes: 0 additions & 2 deletions test/functorch/test_aotdispatch.py
@@ -1098,8 +1098,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
xfail('masked_fill', ''), # could not find kernel
xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi...
xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
# Seems flaky: https://github.com/pytorch/pytorch/issues/88883
skip('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos...
10 changes: 7 additions & 3 deletions test/functorch/test_eager_transforms.py
@@ -3130,13 +3130,16 @@ def normalize_devices(fx_g):
return fx_g

class TestFunctionalize(TestCase):
def _check_functionalize_correctness(self, f, inpt):
def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False):
inpt1 = inpt.clone()
inpt2 = inpt.clone()
inpt3 = inpt.clone()

expected_outputs = f(inpt1)
actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
if skip_vmap:
actual_outputs = functionalize(f)(inpt2)
else:
actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze()
# Right now the flavor of functionalize that also removes view ops
# isn't being used with vmap
# That's because {view}_copy ops don't have batching rules yet
@@ -3206,7 +3209,8 @@ def f(x: torch.Tensor) -> torch.Tensor:
z2, z3 = z1.split(2)
z2.add_(tmp)
return x
self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device))
# See Note [Fix vmap slice_scatter]
self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device), skip_vmap=True)

# Ensure functionalize works with List[Optional[Tensor]] arguments.
# See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085
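For context on what _check_functionalize_correctness exercises, here is a hedged, standalone sketch (assuming functionalize is importable from functorch.experimental, as it was at this point in the codebase): a program that mutates its input through a view should produce the same values under functionalize(f) as under plain eager, which is exactly what the new skip_vmap=True path checks without wrapping in vmap.

import torch
from functorch.experimental import functionalize

# Hedged sketch mirroring the shape of the test helper: compare eager vs
# functionalized execution on a program that mutates its input via a view.
def f(x):
    tmp = torch.ones(8)
    y = x.view(8)
    y.add_(tmp)   # mutates x through the view
    return x

inpt = torch.zeros(4, 2)
expected = f(inpt.clone())
actual = functionalize(f)(inpt.clone())   # the skip_vmap=True path
assert torch.allclose(expected, actual)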
76 changes: 38 additions & 38 deletions test/test_functionalization.py
@@ -147,17 +147,17 @@ def forward(self, a_1):
sum_1 = torch.ops.aten.sum.default(relu)
ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format); sum_1 = None
expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(expand_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]); expand_copy = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(_reshape_alias_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
view_copy_3 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
view_copy_4 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_4, memory_format = torch.contiguous_format); view_copy_4 = None
view_copy_3 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None
new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_3, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128])
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128])
clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format); view_copy_5 = None
threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(_reshape_alias_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1]); _reshape_alias_copy = None
detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1); _reshape_alias_copy_1 = None
view_copy_5 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_5, [16, 64, 128, 128], [1048576, 16384, 128, 1]); view_copy_5 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_2); _reshape_alias_copy_2 = None
view_copy_6 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]); view_copy_3 = None
detach_copy = torch.ops.aten.detach_copy.default(view_copy_6); view_copy_6 = None
view_copy_7 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None
view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [16, 64, 128, 128]); view_copy_7 = None
detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_8); view_copy_8 = None
return detach_copy_1
""") # noqa: B950

@@ -710,40 +710,40 @@ def forward(self, a_1):
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [8])
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None
transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0)
view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None
transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0)
unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None
squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None
split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1])
view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0); _reshape_alias_copy_2 = None
select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4])
view_copy_3 = torch.ops.aten.view_copy.default(add, [8]); add = None
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [2, 4]); view_copy_3 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_4, 1, 0); view_copy_4 = None
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]); _reshape_alias_copy_3 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8])
_reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0); _reshape_alias_copy_4 = None
view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
_reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]); view_copy_4 = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0); _reshape_alias_copy_5 = None
view_copy_5 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None
view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [4, 2]); view_copy_5 = None
view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [8])
view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [2, 4]); view_copy_7 = None
select_copy_1 = torch.ops.aten.select_copy.int(view_copy_8, 0, 0); view_copy_8 = None
view_copy_9 = torch.ops.aten.view_copy.default(view_copy_6, [8]); view_copy_6 = None
view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_10, 1, 0); view_copy_10 = None
unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None
squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
_reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]); getitem_2 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6); select_copy_1 = _reshape_alias_copy_6 = None
view_copy_11 = torch.ops.aten.view_copy.default(getitem_2, [4]); getitem_2 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_11); select_copy_1 = view_copy_11 = None
return add_1
""") # noqa: B950

@@ -756,30 +756,30 @@ def forward(self, a_1):
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view = torch.ops.aten.view.default(add, [8])
_reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None
transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0)
view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None
transpose = torch.ops.aten.transpose.int(view_1, 1, 0)
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None
squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
getitem = split[0]
getitem_1 = split[1]; split = None
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None
select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None
clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
view_1 = torch.ops.aten.view.default(add, [8]); add = None
_reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None
transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None
view_2 = torch.ops.aten.view.default(add, [8]); add = None
view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None
transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None
squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None
squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
_reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None
view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None
view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None
_reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None
select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None
view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None
view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None
view_6 = torch.ops.aten.view.default(view_5, [8]); view_5 = None
view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None
select_1 = torch.ops.aten.select.int(view_7, 0, 0); view_7 = None
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
return getitem
""")
