[functorch] Fix torch.cat batching rule (pytorch#86932)
The bug was discovered in pytorch#86842.

torch.cat has an edge case where it ignores all tensors of shape [0]. So
if any of the BatchedTensors have logical shape [0] but physical shape
[B, 0], then we coerce them to shape [0] by slicing them.

Why don't we just ignore those Tensors? We need to propagate
requires_grad-ness somehow (e.g. if the BatchedTensor wraps a Tensor of
shape [B, 0] that requires grad, then the output must require grad).
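
For context, the edge case in eager-mode torch.cat (no vmap involved) looks roughly like this; a minimal illustration, not part of the commit:

import torch

# torch.cat ignores 1-D tensors of shape [0], so the shape mismatch below is fine.
empty = torch.zeros(0, requires_grad=True)
out = torch.cat([torch.rand(2, 3), empty], dim=1)
print(out.shape)           # torch.Size([2, 3]); `empty` was skipped
# `empty` contributed no data, but it is still an autograd input,
# so requires_grad propagates to the output.
print(out.requires_grad)   # True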

Test Plan:
- new tests
Pull Request resolved: pytorch#86932
Approved by: https://github.com/Chillee
zou3519 authored and pytorchmergebot committed Oct 20, 2022
1 parent c16b7b4 commit b805e1a
Showing 4 changed files with 75 additions and 17 deletions.
aten/src/ATen/functorch/BatchRulesDecompositions.cpp (1 addition, 0 deletions)
@@ -129,6 +129,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   m.impl("index_select_backward", native::index_select_backward_symint);
   OP_DECOMPOSE(inner);
   OP_DECOMPOSE(inverse);
+  OP_DECOMPOSE(concatenate);
   OP_DECOMPOSE(instance_norm);
   OP_DECOMPOSE(kron);
   OP_DECOMPOSE(l1_loss);
aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp (68 additions, 10 deletions)
@@ -7,6 +7,7 @@
 #include <torch/library.h>
 #include <ATen/native/ResizeCommon.h>
 #include <ATen/ATen.h>
+#include <ATen/native/TensorShape.h>
 
 #include <ATen/functorch/DynamicLayer.h>
 #include <ATen/functorch/TensorWrapper.h>
@@ -68,11 +69,15 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
   return dim == 0 || dim == -1;
 }
 
-// This check should probably go into the dispatcher...
-static bool participatesInCurrentLevel(const Tensor& self) {
+static int64_t get_current_level() {
   auto maybe_level = maybeCurrentDynamicLayer();
   TORCH_INTERNAL_ASSERT(maybe_level.has_value());
-  auto current_level = maybe_level->layerId();
+  return maybe_level->layerId();
+}
+
+// This check should probably go into the dispatcher...
+static bool participatesInCurrentLevel(const Tensor& self) {
+  auto current_level = get_current_level();
   auto* maybe_batched_impl = maybeGetBatchedImpl(self);
   if (!maybe_batched_impl) {
     return false;
@@ -611,13 +616,66 @@ Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
     return at::cat(tensors, dim);
   }
-  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
-  auto physical_tensors = fmap(
-      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
-  TORCH_INTERNAL_ASSERT(
-      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
-  auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
-  return physical_views[0].getPhysicalToLogicalMap().apply(result);
+
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+
+  // NB: Probably bad for perf that we're allocating std::vectors for each level, but
+  // what can you do.
+  auto materialized = tensors.materialize();
+  dim = at::legacy_cat_wrap_dim(dim, materialized);
+
+  // Strategy:
+  // we're going to unwrap tensors, move their batch dims to the front,
+  // and put them into `tensors_to_cat`. Tensors that don't have a batch dim
+  // will get one forced onto them.
+  //
+  // Then, we'll do at::cat(tensors_to_cat, ...).
+  //
+  // There's a special case where at::cat ignores tensors that have logical shape
+  // [0]. If we see a Tensor that has logical shape [0] (but physical shape [B, 0]),
+  // we'll just slice the tensor to get a Tensor of shape [0] to pass to at::cat.
+  std::vector<Tensor> tensors_to_cat;
+  tensors_to_cat.reserve(tensors.size());
+  c10::optional<int64_t> bdim_size = c10::nullopt;
+
+  // find the bdim size. Might not exist if all BatchedTensors should be skipped
+  // by cat's special case.
+  for (const auto& tensor : tensors) {
+    if (!participatesInCurrentLevel(tensor)) {
+      continue;
+    }
+    if (at::native::cat_should_skip_tensor(tensor)) {
+      continue;
+    }
+    const auto* batched = unsafeGetBatchedImpl(tensor);
+    bdim_size = batched->value().size(batched->bdim());
+    break;
+  }
+
+  // unwrap batchedtensors; expand out bdims
+  for (const auto& tensor : tensors) {
+    if (!participatesInCurrentLevel(tensor)) {
+      if (at::native::cat_should_skip_tensor(tensor) || !bdim_size.has_value()) {
+        tensors_to_cat.emplace_back(tensor);
+        continue;
+      }
+      tensors_to_cat.emplace_back(ensure_has_bdim(tensor, /*has_bdim*/false, *bdim_size));
+      continue;
+    }
+    const auto* batched = unsafeGetBatchedImpl(tensor);
+    if (at::native::cat_should_skip_tensor(tensor)) {
+      // Special case: slice the tensor to get something of shape [0] to pass to cat
+      // We slice instead of allocate a new tensor to propagate requires_gradness...
+      tensors_to_cat.emplace_back(batched->value().select(/*dim=*/batched->bdim(), /*index=*/0));
+      continue;
+    }
+    tensors_to_cat.emplace_back(moveBatchDimToFront(batched->value(), batched->bdim()));
+  }
+
+  auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
+  c10::optional<int64_t> new_bdim = bdim_size.has_value() ? c10::make_optional((int64_t)0) : nullopt;
+  auto result = at::cat(tensors_to_cat, new_dim);
+  return makeBatched(result, new_bdim, get_current_level());
 }
 
 Tensor block_diag_batching_rule(TensorList tensors) {
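
The Strategy comment in the hunk above is fairly dense. Below is a rough Python analogue of the new rule, operating on already-unwrapped tensors; the function name, arguments, and comments are an illustrative sketch, not functorch internals. It assumes at least one batched input survives cat's skip rule (so bdim_size is known) and that dim is already non-negative:

import torch

def cat_batching_rule_sketch(physical_tensors, is_batched, dim, bdim_size):
    # physical_tensors[i] carries a leading batch dim of size bdim_size iff
    # is_batched[i]; batch dims are assumed to already sit at the front.
    to_cat = []
    for t, batched in zip(physical_tensors, is_batched):
        if batched:
            if t.shape[1:] == (0,):
                # Logical shape [0], physical shape [B, 0]: cat's skip rule only
                # applies to 1-D [0] tensors, so slice down to [0]. Slicing (rather
                # than making a fresh empty tensor) keeps requires_grad flowing
                # through to the output.
                to_cat.append(t.select(0, 0))
            else:
                to_cat.append(t)
        elif t.shape == (0,):
            # cat skips shape-[0] tensors anyway; pass it through unchanged.
            to_cat.append(t)
        else:
            # Unbatched tensor: force a batch dim onto it by broadcasting.
            to_cat.append(t.unsqueeze(0).expand(bdim_size, *t.shape))
    # The logical cat dim shifts by one because of the leading batch dim.
    return torch.cat(to_cat, dim + 1)

The real rule additionally handles the case where every batched input is skipped (bdim_size is never set), in which case it concatenates along the original dim and returns an unbatched result.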
test/functorch/test_ops.py (0 additions, 5 deletions)
@@ -654,7 +654,6 @@ def fn(inp, *args, **kwargs):
         # view doesn't work on sparse
         xfail("to_sparse"),
         xfail("native_batch_norm"),
-        xfail("cat"),  # improper handling for cat empty tensor with non-empty (new test exposed pre-existing bug)
     }))
     @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -722,7 +721,6 @@ def vjp_of_vjp(*args_and_cotangents):
         skip('nn.functional.dropout2d'),  # randomness
         skip('nn.functional.dropout3d', ''),  # randomness
         skip('nn.functional._scaled_dot_product_attention'),  # randomness
-        xfail("cat"),  # improper handling for cat empty tensor with non-empty (new test exposed pre-existing bug)
         xfail('as_strided'),  # as_strided is too wild for us to support, wontfix
         xfail('index_put', ''),  # not possible due to dynamic shapes; we support a subset
         xfail('masked_scatter'),  # dynamic
@@ -853,7 +851,6 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('nn.functional.batch_norm', 'without_cudnn'),
         xfail("native_batch_norm"),
         # ----------------------------------------------------------------------
-        xfail("cat"),  # improper handling for cat empty tensor with non-empty (new test exposed pre-existing bug)
     }

     @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
@@ -1124,7 +1121,6 @@ def test():
         xfail('as_strided_scatter', ''),
         xfail('sparse.sampled_addmm', ''),
         xfail("native_batch_norm"),
-        xfail("cat"),  # improper handling for cat empty tensor with non-empty (new test exposed pre-existing bug)
     }))
     def test_vjpvmap(self, device, dtype, op):
         # NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1377,7 +1373,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
         # input while the running_mean or running_var, which will be updated in
         # place, were not batched.
         xfail("native_batch_norm"),
-        xfail("cat"),  # improper handling for cat empty tensor with non-empty (new test exposed pre-existing bug)
     }))
     @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
test/functorch/test_vmap.py (6 additions, 2 deletions)
@@ -1663,6 +1663,12 @@ def op(*tensors):
             return op

         test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
+        test(get_op(0), (torch.rand(B0, 0), torch.rand(B0, 0)))
+        test(get_op(0), (torch.rand(2), torch.rand(B0, 0)), in_dims=(None, 0))
+        test(get_op(1), (torch.rand(2, 5), torch.rand(B0, 0), torch.rand(2, 3)), in_dims=(None, 0, None))
+        test(get_op(1), (torch.rand(B0, 2, 3), torch.rand(B0, 0)))
+        test(get_op(1), (torch.rand(B0, 2, 3, 4), torch.rand(0)), in_dims=(0, None))
+        test(get_op(0), (torch.rand(0), torch.rand(B0, 2), torch.rand(B0, 0)), in_dims=(None, 0, 0))
         test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
         test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
         test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
@@ -3286,7 +3292,6 @@ def test():
     ))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
-        xfail('cat'),
         xfail('native_batch_norm'),
     }))
     def test_vmap_exhaustive(self, device, dtype, op):
@@ -3304,7 +3309,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
         skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
-        xfail('cat'),
         xfail('complex'),
         xfail('copysign'),
         xfail('native_batch_norm'),
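
For reference, the user-facing behavior the new test_vmap.py cases exercise looks roughly like this (using the functorch.vmap entry point current at the time of this commit; variable names are just for illustration):

import torch
from functorch import vmap

B = 3
per_sample_empty = torch.rand(B, 0, requires_grad=True)  # logical shape [0] per sample
shared = torch.rand(2)                                   # same tensor for every sample

out = vmap(lambda x, y: torch.cat([x, y]), in_dims=(0, None))(per_sample_empty, shared)
print(out.shape)          # torch.Size([3, 2])
print(out.requires_grad)  # True: grad-ness propagates from the empty input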
