Deprecate torch.lu
**BC-breaking note**:

This PR deprecates `torch.lu` in favor of `torch.linalg.lu_factor`.
An upgrade guide is added to the documentation for `torch.lu`.

Note this PR DOES NOT remove `torch.lu`.
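
For reference, the upgrade path looks like this; a minimal sketch, where `A` is any matrix input supported by `torch.lu`:

```python
import torch

A = torch.randn(2, 3, 3)

# Deprecated:
#   LU, pivots = torch.lu(A)
#   LU, pivots, info = torch.lu(A, get_infos=True)

# Replacement:
LU, pivots = torch.linalg.lu_factor(A)
LU, pivots, info = torch.linalg.lu_factor_ex(A)
```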

Pull Request resolved: pytorch#77636

Approved by: https://github.com/malfet
lezcano authored and pytorchmergebot committed Jun 7, 2022
1 parent f091b3f commit f7b9a46
Showing 10 changed files with 57 additions and 41 deletions.
13 changes: 12 additions & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1343,7 +1343,7 @@ bool _requires_fw_or_bw_grad(const Tensor& input) {
// Below of the definitions of the functions operating on a batch that are going to be dispatched
// in the main helper functions for the linear algebra operations

-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Solves a system of linear equations matmul(input, x) = other in-place
// LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here
@@ -2073,6 +2073,17 @@ std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {

// TODO Deprecate this function in favour of linalg_lu_factor_ex
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
+  TORCH_WARN_ONCE(
+    "torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be ",
+    "removed in a future PyTorch release.\n",
+    "LU, pivots = torch.lu(A, compute_pivots)\n",
+    "should be replaced with\n",
+    "LU, pivots = torch.linalg.lu_factor(A, compute_pivots)\n",
+    "and\n",
+    "LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)\n",
+    "should be replaced with\n",
+    "LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)"
+  );
return at::linalg_lu_factor_ex(self, compute_pivots, false);
}

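Since the shim above forwards `torch.lu` to `at::linalg_lu_factor_ex`, the old and new Python entry points should agree; a quick sanity check (a sketch, assuming a build from after this commit that still ships `torch.lu`):

```python
import torch

A = torch.randn(4, 4)

# Old path: still works, but now emits the deprecation warning once per process.
LU_old, piv_old, info_old = torch.lu(A, get_infos=True)

# New path, which the shim forwards to.
LU_new, piv_new, info_new = torch.linalg.lu_factor_ex(A)

torch.testing.assert_close(LU_old, LU_new)
torch.testing.assert_close(piv_old, piv_new)
```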
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
@@ -1920,7 +1920,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.lu on a CUDA tensor requires compiling ",
"Calling linalg.lu_factor on a CUDA tensor requires compiling ",
"PyTorch with MAGMA. Please rebuild with MAGMA.");
#else
auto input_data = input.data_ptr<scalar_t>();
6 changes: 3 additions & 3 deletions test/mobile/model_test/math_ops.py
@@ -441,9 +441,9 @@ def blas_lapack_ops(self):
# torch.logdet(m),
# torch.slogdet(m),
# torch.lstsq(m, m),
-            # torch.lu(m),
-            # torch.lu_solve(m, *torch.lu(m)),
-            # torch.lu_unpack(*torch.lu(m)),
+            # torch.linalg.lu_factor(m),
+            # torch.lu_solve(m, *torch.linalg.lu_factor(m)),
+            # torch.lu_unpack(*torch.linalg.lu_factor(m)),
torch.matmul(m, m),
torch.matrix_power(m, 2),
# torch.matrix_rank(m),
14 changes: 1 addition & 13 deletions test/test_jit.py
@@ -9150,20 +9150,8 @@ def istft(input, n_fft):
inps2 = (stft(*inps), inps[1])
self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))

-        def lu(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return torch.lu(x)
-
-        self.checkScript(lu, (torch.randn(2, 3, 3),))
-
-        def lu_infos(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor]
-            return torch.lu(x, get_infos=True)
-
-        self.checkScript(lu_infos, (torch.randn(2, 3, 3),))
-
        def lu_unpack(x):
-            A_LU, pivots = torch.lu(x)
+            A_LU, pivots = torch.linalg.lu_factor(x)
            return torch.lu_unpack(A_LU, pivots)

for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
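As the updated JIT test shows, `torch.lu_unpack` composes directly with `torch.linalg.lu_factor`; a sketch of recovering the explicit factors:

```python
import torch

A = torch.randn(3, 3)
LU, pivots = torch.linalg.lu_factor(A)

# Unpack the packed LU representation into explicit P, L, U factors.
P, L, U = torch.lu_unpack(LU, pivots)

# A is reconstructed (up to floating-point error) as P @ L @ U.
torch.testing.assert_close(P @ L @ U, A)
```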
19 changes: 5 additions & 14 deletions test/test_linalg.py
@@ -4213,7 +4213,7 @@ def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_c
size_b = size_b[1:]

if well_conditioned:
-            PLU = torch.lu_unpack(*torch.lu(make_randn(*size_a)))
+            PLU = torch.linalg.lu(make_randn(*size_a))
if uni:
# A = L from PLU
A = PLU[1].transpose(-2, -1).contiguous()
@@ -4900,15 +4900,6 @@ def call_torch_fn(*args, **kwargs):
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))

-        if torch._C.has_lapack:
-            # lu
-            A_LU, pivots = fn(torch.lu, (0, 5, 5))
-            self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
-            A_LU, pivots = fn(torch.lu, (0, 0, 0))
-            self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
-            A_LU, pivots = fn(torch.lu, (2, 0, 0))
-            self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])

@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_and_complex_types_and(
@@ -5276,7 +5267,7 @@ def gen_matrices():
@dtypes(torch.double)
def test_lu_unpack_check_input(self, device, dtype):
x = torch.rand(5, 5, 5, device=device, dtype=dtype)
-        lu_data, lu_pivots = torch.lu(x, pivot=True)
+        lu_data, lu_pivots = torch.linalg.lu_factor(x)

with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
torch.lu_unpack(lu_data, lu_pivots.long())
@@ -7163,7 +7154,7 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):

b = torch.randn(*b_dims, dtype=dtype, device=device)
A = make_A(*A_dims)
-        LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
+        LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
self.assertEqual(info, torch.zeros_like(info))
return b, A, LU_data, LU_pivots

@@ -7207,7 +7198,7 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
# Tests tensors with 0 elements
b = torch.randn(3, 0, 3, dtype=dtype, device=device)
A = torch.randn(3, 0, 0, dtype=dtype, device=device)
-            LU_data, LU_pivots = torch.lu(A)
+            LU_data, LU_pivots = torch.linalg.lu_factor(A)
self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))

sub_test(True)
@@ -7242,7 +7233,7 @@ def run_test(A_dims, b_dims, pivot=True):
A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
b = make_tensor(b_dims, dtype=dtype, device=device)
x_exp = np.linalg.solve(A.cpu(), b.cpu())
-            LU_data, LU_pivots = torch.lu(A, pivot=pivot)
+            LU_data, LU_pivots = torch.linalg.lu_factor(A)
x = torch.lu_solve(b, LU_data, LU_pivots)
self.assertEqual(x, x_exp)

12 changes: 6 additions & 6 deletions torch/_torch_docs.py
@@ -5803,16 +5803,16 @@ def merge_dicts(*dicts):
lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor
Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
-LU factorization of A from :meth:`torch.lu`.
+LU factorization of A from :func:`~linalg.lu_factor`.
This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
Arguments:
b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
is zero or more batch dimensions.
-    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu` of size :math:`(*, m, m)`,
+    LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`,
        where :math:`*` is zero or more batch dimensions.
-    LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`torch.lu` of size :math:`(*, m)`,
+    LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`,
where :math:`*` is zero or more batch dimensions.
The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of
:attr:`LU_data`.
Expand All @@ -5824,9 +5824,9 @@ def merge_dicts(*dicts):
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> A_LU = torch.lu(A)
>>> x = torch.lu_solve(b, *A_LU)
>>> torch.norm(torch.bmm(A, x) - b)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> x = torch.lu_solve(b, LU, pivots)
>>> torch.dist(A @ x, b)
tensor(1.00000e-07 *
2.8312)
""".format(**common_args))
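The `torch.linalg` namespace also has a solver that consumes the `lu_factor` output directly; a sketch of the fully `linalg`-based pipeline, assuming a version that ships `torch.linalg.lu_solve` (note the argument order differs from `torch.lu_solve`):

```python
import torch

A = torch.randn(2, 3, 3)
b = torch.randn(2, 3, 1)

LU, pivots = torch.linalg.lu_factor(A)

# torch.lu_solve takes (b, LU, pivots); torch.linalg.lu_solve takes (LU, pivots, b).
x = torch.linalg.lu_solve(LU, pivots, b)

print(torch.dist(A @ x, b))  # ~0 up to floating-point error
```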
6 changes: 5 additions & 1 deletion torch/backends/cuda/__init__.py
@@ -134,10 +134,14 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend]
* :func:`torch.linalg.cholesky_ex`
* :func:`torch.cholesky_solve`
* :func:`torch.cholesky_inverse`
-    * :func:`torch.lu`
+    * :func:`torch.linalg.lu_factor`
+    * :func:`torch.linalg.lu`
+    * :func:`torch.linalg.lu_solve`
* :func:`torch.linalg.qr`
* :func:`torch.linalg.eigh`
    * :func:`torch.linalg.eigvalsh`
* :func:`torch.linalg.svd`
* :func:`torch.linalg.svdvals`
'''

if backend is None:
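For context, the backend preference documented above is queried and set through the same function; a short sketch (the hint is a no-op on CPU-only builds):

```python
import torch

# With no argument, returns the currently preferred backend.
print(torch.backends.cuda.preferred_linalg_library())

# Hint PyTorch to prefer MAGMA (alternatives: "cusolver", "default")
# for the linear-algebra ops listed in the docstring above.
torch.backends.cuda.preferred_linalg_library("magma")
```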
17 changes: 17 additions & 0 deletions torch/functional.py
@@ -1552,6 +1552,23 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
``True``.

+    .. warning::
+
+        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
+        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
+        future PyTorch release.
+        ``LU, pivots = torch.lu(A, compute_pivots)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
+
+        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
+
.. note::
* The returned permutation matrix for every matrix in the batch is
represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
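The 1-indexed pivot convention mentioned in the note is easy to observe; a small sketch:

```python
import torch

A = torch.randn(4, 4)
LU, pivots = torch.linalg.lu_factor(A)

# Pivots follow the LAPACK convention: torch.int32 and 1-indexed.
print(pivots.dtype)  # torch.int32
print(pivots)        # entries are in 1..4, one per row
```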
7 changes: 6 additions & 1 deletion torch/linalg/__init__.py
@@ -336,6 +336,11 @@
Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
the output has the same batch dimensions.

""" + fr"""
.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
{common_notes["sync_note"]}
""" + r"""

.. seealso::

:func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the
@@ -372,7 +377,7 @@
the output has the same batch dimensions.

""" + fr"""
-.. note:: This function is computed using :func:`torch.lu`.
+.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
{common_notes["sync_note"]}
""" + r"""

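The notes above say `det` (and `slogdet`) are computed via `lu_factor`: for `A = P L U` with unit-diagonal `L`, `det(A) = det(P) * prod(diag(U))`, and `det(P)` is `-1` raised to the number of row swaps. A sketch with an illustrative helper (`det_via_lu` is not a PyTorch function):

```python
import torch

def det_via_lu(A: torch.Tensor) -> torch.Tensor:
    # The diagonal of the packed LU factorization is the diagonal of U.
    LU, pivots = torch.linalg.lu_factor(A)
    diag_U = LU.diagonal(dim1=-2, dim2=-1)
    # pivots is 1-indexed: pivots[i] != i + 1 means row i was swapped,
    # and each swap flips the sign of det(P).
    idx = torch.arange(1, pivots.shape[-1] + 1, device=pivots.device)
    num_swaps = (pivots != idx).sum(dim=-1)
    sign = 1 - 2 * (num_swaps % 2).to(A.dtype)
    return sign * diag_U.prod(dim=-1)

A = torch.randn(5, 5)
torch.testing.assert_close(det_via_lu(A), torch.linalg.det(A))
```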
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -6904,7 +6904,7 @@ def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs
make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)

def out_fn(output):
-        if op_info.name in ("linalg.lu"):
+        if op_info.name == "linalg.lu":
return output[1], output[2]
else:
return output
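The `out_fn` fix above deserves a note: `("linalg.lu")` is a parenthesized string, not a tuple, so `in` performed substring matching rather than membership testing. A quick illustration of the pitfall:

```python
# ("linalg.lu") is just a str, so `in` does substring matching:
print("lu" in ("linalg.lu"))   # True  -- accidental substring match
# A one-element tuple needs a trailing comma:
print("lu" in ("linalg.lu",))  # False -- tuple membership, as intended
```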
