Implement batching rules for some view ops (pytorch#42248)
Summary:
Pull Request resolved: pytorch#42248

Including:
- torch.diagonal
- torch.t
- torch.select
- Tensor.expand_as
- Tensor slicing.
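
For illustration only (not part of this diff), a minimal sketch of the kind of
call these rules serve, assuming a `torch.vmap` entry point like the prototype
API the tests below exercise (the exact location of vmap has moved between
releases):

import torch

x = torch.randn(8, 5, 5)  # a batch of eight 5x5 matrices

# Map torch.diagonal over the batch dim instead of writing a Python loop.
# With the batching rule added here, this should go through a view on the
# underlying batched tensor rather than the slow for-loop fallback.
batched_diag = torch.vmap(lambda m: torch.diagonal(m, 0, 0, 1))(x)

# Loop-based reference for comparison.
expected = torch.stack([torch.diagonal(m, 0, 0, 1) for m in x])
assert torch.allclose(batched_diag, expected)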

Please let me know in the future if it would be easier to review these
separately (I put five operators into this PR because each
implementation is relatively simple).

Test Plan:
- new tests in `test/test_vmap.py`.
- I would like a more structured/automated way of testing these, but my
previous attempts at building one turned out very complicated.

Reviewed By: ezyang

Differential Revision: D22846273

Pulled By: zou3519

fbshipit-source-id: 8e45ebe11174512110faf1ee0fdc317a25e8b7ac
zou3519 authored and facebook-github-bot committed Aug 3, 2020
1 parent 2f8d5b6 commit 4cdbe5c
Showing 2 changed files with 126 additions and 2 deletions.
33 changes: 31 additions & 2 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -143,6 +143,28 @@ Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
  return self_physical.newLogicalFromPhysical(result);
}
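
// The view batching rules below all follow the same pattern: convert the
// logical (user-facing) tensor to its physical counterpart via
// MultiBatchVmapTransform::logicalToPhysical, translate any logical dims to
// physical dims, call the existing view op on the physical tensor, and wrap
// the result back up as a logical BatchedTensor.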

Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().select(dim_physical, index);
  return self_physical.newLogicalFromPhysical(result);
}

Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().slice(dim_physical, start, end, step);
  return self_physical.newLogicalFromPhysical(result);
}

Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto dim2_physical = self_physical.getPhysicalDim(dim2);
  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
  return self_physical.newLogicalFromPhysical(result);
}

TORCH_LIBRARY_IMPL(_, Batched, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
@@ -158,11 +180,18 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {

  m.impl_UNBOXED("sum.dim_IntList", sum_batching_rule);
  m.impl_UNBOXED("mul.Tensor", mul_batching_rule);

  // view operations
  m.impl("diagonal", diagonal_batching_rule);
  m.impl("expand", expand_batching_rule);
  m.impl("expand_as", native::expand_as); // composite wrt autograd
  m.impl("permute", permute_batching_rule);
  m.impl("select.int", select_batching_rule);
  m.impl("slice.Tensor", slice_batching_rule);
  m.impl("squeeze.dim", squeeze_dim_batching_rule);
  m.impl("t", native::t); // composite wrt autograd
  m.impl("transpose.int", transpose_int_batching_rule);
  m.impl("unsqueeze", unsqueeze_batching_rule);
}

} // namespace at
95 changes: 95 additions & 0 deletions test/test_vmap.py
@@ -499,5 +499,100 @@ def run_test(batch_size):
        run_test(batch_size=1237)


def slice_inputs(inputs, bdims, i):
    result = []
    for inp, bdim in zip(inputs, bdims):
        if bdim is None:
            result.append(inp)
        else:
            result.append(inp.select(bdim, i))
    return tuple(result)


def reference_vmap(op, inputs, in_dims=0, out_dims=0):
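    # Loop-based reference semantics for vmap: slice every batched input along
    # its batch dim, apply `op` to each slice, and stack the per-slice results
    # along `out_dims`.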
    if isinstance(in_dims, int):
        in_dims = (in_dims,) * len(inputs)
    bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
    assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes)
    bdim_size = bdim_sizes[0]
    results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))
    # reference_vmap only supports functions that return a single Tensor output
    assert all(isinstance(result, torch.Tensor) for result in results)
    if isinstance(out_dims, int):
        out_dims = (out_dims,) * 1
    return torch.stack(results, dim=out_dims[0])


class TestVmapOperators(TestCase):
    def _vmap_view_test(self, op, inputs, in_dims=0, out_dims=0):
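        # Checks that vmap(op) matches the loop-based reference and that the
        # result is a view of (shares storage with) the first input.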
        result = vmap(op, in_dims, out_dims)(*inputs)
        reference_result = reference_vmap(op, inputs, in_dims, out_dims)
        self.assertEqual(result, reference_result)
        self.assertEqual(result.data_ptr() - result.storage_offset() * result.element_size(),
                         inputs[0].data_ptr(),
                         msg="result was not a view of the first input!")

        # Assuming input[0] is a floating-point tensor. Check if the vmap
        # operation propagates the requires_grad flag. Some vmap operators are
        # implemented in a way that assumes that they are composite with respect
        # to autograd. If the operator is ever changed to not be composite with
        # respect to autograd, then the following check should fail.
        inputs_clone = list(inputs)
        inputs_clone[0] = inputs[0].clone().requires_grad_()
        result = vmap(op, in_dims, out_dims)(*inputs_clone)
        self.assertTrue(result.requires_grad)

    def test_diagonal(self):
        tensor = torch.randn(3, 5, 7, 11, 13)
        test = self._vmap_view_test
        op = torch.diagonal
        test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
        test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
        test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
        test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
        test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
        test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
             (tensor,), in_dims=1, out_dims=1)

    def test_expand_as(self):
        op = torch.Tensor.expand_as
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
        test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
        test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1))
        test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
        test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))

    def test_select(self):
        op = torch.select
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)

    def test_slice(self):
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
        test(vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2)
        test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
             (torch.rand(3, 5, B0, B1, B2),), in_dims=2)

    def test_t(self):
        op = torch.t
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 5),))
        test(op, (torch.rand(2, B0, 5),), in_dims=1)
        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)


if __name__ == '__main__':
    run_tests()
