Revert "Use register_meta for everything in meta_registrations (pytor…
Browse files Browse the repository at this point in the history
…ch#84297)"

This reverts commit 8cd296f.

Reverted pytorch#84297 on behalf of https://github.com/suo because it broke test_proxy_tensor on master.
pytorchmergebot committed Aug 31, 2022
1 parent bf67589 commit 14093b5
Showing 2 changed files with 13 additions and 9 deletions.
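
For readers skimming the diff below: pytorch#84297 had routed every meta-function registration through the register_meta decorator; this revert restores direct registrations on a torch.library.Library handle for the "Meta" dispatch key. A minimal sketch of the restored pattern, assuming a PyTorch build where the op being registered does not already have a Meta kernel (the kernel body is illustrative, not taken from this commit):

import torch

# Library handle that attaches kernels for the "Meta" dispatch key
# to existing aten operators (the pattern this revert restores).
meta_lib = torch.library.Library("aten", "IMPL", "Meta")

# Used as a decorator, torch.library.impl registers the function as the
# Meta-backend implementation of the named aten op.
@torch.library.impl(meta_lib, "vdot")
def vdot_meta(self, other):
    # Meta kernels compute only output metadata (shape, dtype, device);
    # vdot of two 1-D tensors yields a 0-dim tensor.
    return self.new_empty(())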
test/test_proxy_tensor.py (3 additions, 0 deletions)
@@ -891,6 +891,7 @@ def f(a, b):
 symbolic_tensor_failures = {
     # Needs complex-value support
     xfail('polar'),
+    xfail('complex'),
     xfail('linalg.eig'),
     xfail('__getitem__', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('__rmatmul__', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
@@ -1045,6 +1046,7 @@ def f(a, b):
     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
+    xfail('linalg.vecdot', ''),  # Could not run 'aten::vdot' with arguments from the 'Meta' backend. This could be ...
     xfail('linalg.vector_norm', ''),  # TensorImpl do not have numel
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp', ''),  # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
@@ -1241,6 +1243,7 @@ def f(a, b):
     xfail('unfold', ''),  # aten.unfold.default - couldn't find symbolic meta function/decomposition
     xfail('var_mean', ''),  # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
     xfail('var', ''),  # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
+    xfail('vdot', ''),  # aten.vdot.default - couldn't find symbolic meta function/decomposition
     xfail('view_as_complex', ''),  # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
     xfail('view_as', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('vsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
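The xfail entries above mark OpInfo-driven tests as expected failures until the corresponding ops regain meta support. As a rough illustration of the mechanism, here is a self-contained sketch using stand-in helpers (EXPECTED_FAILURES and maybe_xfail are illustrative, not PyTorch's actual xfail/skipOps machinery):

import unittest

# Illustrative stand-in for the xfail(...) entries above: (op, variant)
# pairs whose exhaustive tracing tests are expected to fail.
EXPECTED_FAILURES = {("vdot", ""), ("complex", ""), ("linalg.vecdot", "")}

def maybe_xfail(op_name, variant=""):
    # Wrap a test so it is reported as an expected failure when the op
    # is known to lack a symbolic meta function/decomposition.
    def decorator(fn):
        if (op_name, variant) in EXPECTED_FAILURES:
            return unittest.expectedFailure(fn)
        return fn
    return decorator

class TestProxyTensorSketch(unittest.TestCase):
    @maybe_xfail("vdot")
    def test_vdot_symbolic(self):
        raise NotImplementedError("aten.vdot.default has no symbolic meta")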
torch/_meta_registrations.py (10 additions, 9 deletions)
@@ -17,7 +17,7 @@
 
 aten = torch.ops.aten
 
-_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
+meta_lib = torch.library.Library("aten", "IMPL", "Meta")
 
 meta_table = {}
 
@@ -32,7 +32,7 @@ def add_func(op):
                 if op._overloadname != "default"
                 else op.overloadpacket.__name__
             )
-            _meta_lib_dont_use_me_use_register_meta.impl(name, f)
+            meta_lib.impl(name, f)
 
         tree_map(add_func, op)
         return f
@@ -195,7 +195,7 @@ def _compute_reduction_shape(self, dims, keepdim):
     return utils.compute_reduction_output_shape(self.shape, dims)
 
 
-@register_meta(aten.bernoulli.out)
+@torch.library.impl(meta_lib, "bernoulli.out")
 def meta_bernoulli(self, *, generator=None, out):
     torch._resize_output_(out, self.size(), self.device)
     return out
@@ -380,7 +380,8 @@ def meta_repeat_interleave_Tensor(repeats, output_size=None):
     return repeats.new_empty(output_size)
 
 
-@register_meta([aten.complex.default, aten.complex.out])
+@torch.library.impl(meta_lib, "complex")
+@torch.library.impl(meta_lib, "complex.out")
 @out_wrapper()
 def meta_complex(real, imag):
     assert real.dtype.is_floating_point
@@ -389,7 +390,7 @@ def meta_complex(real, imag):
     return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
 
 
-@register_meta(aten.vdot.default)
+@torch.library.impl(meta_lib, "vdot")
 def vdot(self, other):
     if not self.is_complex:
         return torch.dot(self, other)
@@ -538,7 +539,7 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
     return self.new_empty(self.size())
 
 
-@register_meta(aten._cdist_forward.default)
+@torch.library.impl(meta_lib, "_cdist_forward")
 def meta_cdist_forward(x1, x2, p, compute_mode):
     check(
         x1.dim() >= 2,
@@ -574,7 +575,7 @@ def meta_cdist_forward(x1, x2, p, compute_mode):
     return x1.new_empty(output_shape)
 
 
-@register_meta(aten._embedding_bag.default)
+@torch.library.impl(meta_lib, "_embedding_bag")
 def meta_embedding_bag(
     weight,
     indices,
@@ -683,7 +684,7 @@ def meta_diag(self, dim=0):
     return self.new_empty((sz,))
 
 
-@register_meta(aten._embedding_bag_forward_only.default)
+@torch.library.impl(meta_lib, "_embedding_bag_forward_only")
 def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
     output, offset2bag, bag_size, max_indices = meta_embedding_bag(
         weight, indices, offsets, *args
@@ -734,7 +735,7 @@ def meta_nanmedian_dim(input, dim=-1, keepdim=False):
     )
 
 
-@register_meta(aten.logical_not_.default)
+@torch.library.impl(meta_lib, "logical_not_")
 def meta_logical_not_(self):
     return self
 
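As a usage note, the meta kernels re-registered above are what allow shape and dtype inference to run on tensors that own no storage. A minimal sketch, assuming a build where the ops involved (here aten::vdot and the aten::dot it may forward to) have Meta kernels, as they do after this commit:

import torch

# Two "meta" tensors: they carry shape and dtype but no data storage.
a = torch.empty(8, device="meta")
b = torch.empty(8, device="meta")

# Dispatch goes to the Meta-backend kernel, so the result's metadata
# is computed without touching any real data.
out = torch.vdot(a, b)
print(out.shape, out.device)  # torch.Size([]) on the meta device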
