diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 00f9e0189c545d..2e75147e6c6911 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1343,7 +1343,7 @@ bool _requires_fw_or_bw_grad(const Tensor& input) {
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
 // in the main helper functions for the linear algebra operations
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 // Solves a system of linear equations matmul(input, x) = other in-place
 // LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here
@@ -2073,6 +2073,17 @@ std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
 
 // TODO Deprecate this function in favour of linalg_lu_factor_ex
 std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
+  TORCH_WARN_ONCE(
+    "torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be ",
+    "removed in a future PyTorch release.\n",
+    "LU, pivots = torch.lu(A, compute_pivots)\n",
+    "should be replaced with\n",
+    "LU, pivots = torch.linalg.lu_factor(A, compute_pivots)\n",
+    "and\n",
+    "LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)\n",
+    "should be replaced with\n",
+    "LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)"
+  );
   return at::linalg_lu_factor_ex(self, compute_pivots, false);
 }
 
diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
index 8b4ddb6847083f..bbc9f2caac94f6 100644
--- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
@@ -1920,7 +1920,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv
 #if !AT_MAGMA_ENABLED()
   TORCH_CHECK(
       false,
-      "Calling torch.lu on a CUDA tensor requires compiling ",
+      "Calling linalg.lu_factor on a CUDA tensor requires compiling ",
       "PyTorch with MAGMA. Please rebuild with MAGMA.");
 #else
   auto input_data = input.data_ptr<scalar_t>();
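
A minimal sketch (not part of the diff) of the migration that the new TORCH_WARN_ONCE message describes, assuming a PyTorch build where torch.linalg.lu_factor and torch.linalg.lu_factor_ex are available:

    import torch

    A = torch.randn(2, 3, 3)

    # Deprecated call; after this patch it emits the warning added above:
    #   LU, pivots = torch.lu(A)

    # Replacement when the LAPACK/MAGMA error codes are not needed:
    LU, pivots = torch.linalg.lu_factor(A)

    # Replacement when the error codes are needed (info == 0 means success):
    LU, pivots, info = torch.linalg.lu_factor_ex(A)
    assert torch.all(info == 0)
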
diff --git a/test/mobile/model_test/math_ops.py b/test/mobile/model_test/math_ops.py
index f89e3bca70d6d3..551c712ed38bb9 100644
--- a/test/mobile/model_test/math_ops.py
+++ b/test/mobile/model_test/math_ops.py
@@ -441,9 +441,9 @@ def blas_lapack_ops(self):
             # torch.logdet(m),
             # torch.slogdet(m),
             # torch.lstsq(m, m),
-            # torch.lu(m),
-            # torch.lu_solve(m, *torch.lu(m)),
-            # torch.lu_unpack(*torch.lu(m)),
+            # torch.linalg.lu_factor(m),
+            # torch.lu_solve(m, *torch.linalg.lu_factor(m)),
+            # torch.lu_unpack(*torch.linalg.lu_factor(m)),
             torch.matmul(m, m),
             torch.matrix_power(m, 2),
             # torch.matrix_rank(m),
diff --git a/test/test_jit.py b/test/test_jit.py
index 9a5a3d143e1e6a..b063dd65c56d2e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9150,20 +9150,8 @@ def istft(input, n_fft):
         inps2 = (stft(*inps), inps[1])
         self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))
 
-        def lu(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return torch.lu(x)
-
-        self.checkScript(lu, (torch.randn(2, 3, 3),))
-
-        def lu_infos(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor]
-            return torch.lu(x, get_infos=True)
-
-        self.checkScript(lu_infos, (torch.randn(2, 3, 3),))
-
         def lu_unpack(x):
-            A_LU, pivots = torch.lu(x)
+            A_LU, pivots = torch.linalg.lu_factor(x)
             return torch.lu_unpack(A_LU, pivots)
 
         for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 3da4aa1cf74b3b..206052a7422e00 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -4213,7 +4213,7 @@ def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_c
             size_b = size_b[1:]
 
         if well_conditioned:
-            PLU = torch.lu_unpack(*torch.lu(make_randn(*size_a)))
+            PLU = torch.linalg.lu(make_randn(*size_a))
             if uni:
                 # A = L from PLU
                 A = PLU[1].transpose(-2, -1).contiguous()
@@ -4900,15 +4900,6 @@ def call_torch_fn(*args, **kwargs):
         self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
         self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))
 
-        if torch._C.has_lapack:
-            # lu
-            A_LU, pivots = fn(torch.lu, (0, 5, 5))
-            self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
-            A_LU, pivots = fn(torch.lu, (0, 0, 0))
-            self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
-            A_LU, pivots = fn(torch.lu, (2, 0, 0))
-            self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
-
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_and_complex_types_and(
@@ -5276,7 +5267,7 @@ def gen_matrices():
     @dtypes(torch.double)
     def test_lu_unpack_check_input(self, device, dtype):
         x = torch.rand(5, 5, 5, device=device, dtype=dtype)
-        lu_data, lu_pivots = torch.lu(x, pivot=True)
+        lu_data, lu_pivots = torch.linalg.lu_factor(x)
 
         with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
             torch.lu_unpack(lu_data, lu_pivots.long())
@@ -7163,7 +7154,7 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
 
         b = torch.randn(*b_dims, dtype=dtype, device=device)
         A = make_A(*A_dims)
-        LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
+        LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
         self.assertEqual(info, torch.zeros_like(info))
         return b, A, LU_data, LU_pivots
@@ -7207,7 +7198,7 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
         # Tests tensors with 0 elements
         b = torch.randn(3, 0, 3, dtype=dtype, device=device)
         A = torch.randn(3, 0, 0, dtype=dtype, device=device)
-        LU_data, LU_pivots = torch.lu(A)
+        LU_data, LU_pivots = torch.linalg.lu_factor(A)
         self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))
 
         sub_test(True)
@@ -7242,7 +7233,7 @@ def run_test(A_dims, b_dims, pivot=True):
             A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
             b = make_tensor(b_dims, dtype=dtype, device=device)
             x_exp = np.linalg.solve(A.cpu(), b.cpu())
-            LU_data, LU_pivots = torch.lu(A, pivot=pivot)
+            LU_data, LU_pivots = torch.linalg.lu_factor(A)
             x = torch.lu_solve(b, LU_data, LU_pivots)
             self.assertEqual(x, x_exp)
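
The test_linalg.py change above swaps torch.lu_unpack(*torch.lu(A)) for torch.linalg.lu(A); a quick sanity check (not part of the diff) of the equivalence the test relies on, assuming both APIs are present:

    import torch

    A = torch.randn(5, 3, 3)

    # Old pattern removed from the test:
    #   P, L, U = torch.lu_unpack(*torch.lu(A))
    # New pattern; linalg.lu returns the (P, L, U) factors directly:
    P, L, U = torch.linalg.lu(A)

    # The factors reconstruct the original batch of matrices.
    torch.testing.assert_close(P @ L @ U, A)
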
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index f83904ac6cfe91..18e01ff419318d 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5803,16 +5803,16 @@ def merge_dicts(*dicts):
 lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor
 
 Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
-LU factorization of A from :meth:`torch.lu`.
+LU factorization of A from :func:`~linalg.lu_factor`.
 
 This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
 
 Arguments:
     b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
                 is zero or more batch dimensions.
-    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu` of size :math:`(*, m, m)`,
+    LU_data (Tensor): the pivoted LU factorization of A from :func:`~linalg.lu_factor` of size :math:`(*, m, m)`,
                        where :math:`*` is zero or more batch dimensions.
-    LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`torch.lu` of size :math:`(*, m)`,
+    LU_pivots (IntTensor): the pivots of the LU factorization from :func:`~linalg.lu_factor` of size :math:`(*, m)`,
                            where :math:`*` is zero or more batch dimensions.
                            The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of
                            :attr:`LU_data`.
@@ -5824,9 +5824,9 @@ def merge_dicts(*dicts):
 
     >>> A = torch.randn(2, 3, 3)
     >>> b = torch.randn(2, 3, 1)
-    >>> A_LU = torch.lu(A)
-    >>> x = torch.lu_solve(b, *A_LU)
-    >>> torch.norm(torch.bmm(A, x) - b)
+    >>> LU, pivots = torch.linalg.lu_factor(A)
+    >>> x = torch.lu_solve(b, LU, pivots)
+    >>> torch.dist(A @ x, b)
     tensor(1.00000e-07 *
            2.8312)
 """.format(**common_args))
diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py
index 6b823389c8e410..ec7a68ec5be046 100644
--- a/torch/backends/cuda/__init__.py
+++ b/torch/backends/cuda/__init__.py
@@ -134,10 +134,14 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend]
     * :func:`torch.linalg.cholesky_ex`
     * :func:`torch.cholesky_solve`
     * :func:`torch.cholesky_inverse`
-    * :func:`torch.lu`
+    * :func:`torch.linalg.lu_factor`
+    * :func:`torch.linalg.lu`
+    * :func:`torch.linalg.lu_solve`
     * :func:`torch.linalg.qr`
     * :func:`torch.linalg.eigh`
+    * :func:`torch.linalg.eigvalsh`
     * :func:`torch.linalg.svd`
+    * :func:`torch.linalg.svdvals`
     '''
 
     if backend is None:
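
A runnable version (not part of the diff) of the updated lu_solve example, cross-checking the factored solve against torch.linalg.solve:

    import torch

    A = torch.randn(2, 3, 3)
    b = torch.randn(2, 3, 1)

    LU, pivots = torch.linalg.lu_factor(A)
    x = torch.lu_solve(b, LU, pivots)

    # Solving via the LU factorization agrees with solving directly;
    # the residual is on the order of machine precision.
    torch.testing.assert_close(x, torch.linalg.solve(A, b))
    print(torch.dist(A @ x, b))
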
diff --git a/torch/functional.py b/torch/functional.py
index 29a66f7f160084..75d8e365140aa1 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1552,6 +1552,23 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
     pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
     ``True``.
 
+    .. warning::
+
+        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
+        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
+        future PyTorch release.
+        ``LU, pivots = torch.lu(A, compute_pivots)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
+
+        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
+
+        .. code:: python
+
+            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
+
     .. note::
         * The returned permutation matrix for every matrix in the batch is
           represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 4c71919c3cd7e2..9cd50a88b44b6a 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -336,6 +336,11 @@
 Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
 the output has the same batch dimensions.
 
+""" + fr"""
+.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
+          {common_notes["sync_note"]}
+""" + r"""
+
 .. seealso::
 
         :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the
@@ -372,7 +377,7 @@
 the output has the same batch dimensions.
 
 """ + fr"""
-.. note:: This function is computed using :func:`torch.lu`.
+.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
           {common_notes["sync_note"]}
 """ + r"""
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7d2e2990b6950b..ec671cee7f19a9 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6904,7 +6904,7 @@ def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwarg
     make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
 
     def out_fn(output):
-        if op_info.name in ("linalg.lu"):
+        if op_info.name == "linalg.lu":
             return output[1], output[2]
         else:
             return output
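
For context on the last hunk (illustration only, not part of the diff): ("linalg.lu") is not a one-element tuple but simply the string "linalg.lu", so the original `in` check performed a substring test; `==` is the intended exact comparison:

    print("lu" in ("linalg.lu"))   # True  -- substring match on a plain string
    print("lu" in ("linalg.lu",))  # False -- membership in a one-element tuple
    print("lu" == "linalg.lu")     # False -- exact string comparison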