
Commit

Merge remote-tracking branch 'tensorly/main' into parafac2_stopping_condition
MarieRoald committed May 6, 2021
2 parents b3b6faf + 23ba632 commit c225e48
Showing 15 changed files with 321 additions and 52 deletions.
1 change: 1 addition & 0 deletions doc/modules/api.rst
@@ -382,6 +382,7 @@ Functions

regression.MSE
regression.RMSE
factors.congruence_coefficient


:mod:`tensorly.random`: Sampling tensors
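For reference, a minimal sketch of how the newly documented factors.congruence_coefficient can be used. The import path and the tuple return value (score first) are taken from the new test further down; the behaviour claimed in the comments is an assumption based on the metric's purpose:

```python
import numpy as np
import tensorly as tl
from tensorly.metrics.factors import congruence_coefficient

rng = np.random.RandomState(0)
A = tl.tensor(rng.uniform(size=(10, 3)))
B = A[:, [2, 0, 1]]  # same factor matrix with permuted columns

# The first element of the returned tuple is the congruence score,
# maximized over column permutations (per the usage in test_cp.py).
score = congruence_coefficient(A, B, absolute_value=True)[0]
print(score)  # expected ~1.0 for a pure column permutation
```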
21 changes: 20 additions & 1 deletion tensorly/backend/core.py
@@ -348,6 +348,26 @@ def where(condition, x, y):
"""
raise NotImplementedError

@staticmethod
def any(tensor, axis=None, keepdims=False, **kwargs):
"""Test whether any array element along a given axis evaluates to True.
Parameters
----------
tensor : tensor
input tensor to check for non-zero values
axis : int or None, default is None
optional, indicates an axis along which to check for non-zero values
keepdims : bool, default is False
Returns
-------
bool or tensor
if axis is None, returns a bool indicating whether any value is non-zero
otherwise, returns a tensor of bools.
"""
return tensor.any(axis=axis, keepdims=keepdims, **kwargs)

@staticmethod
def clip(tensor, a_min=None, a_max=None):
"""Clip the values of a tensor to within an interval.
@@ -702,7 +722,6 @@ def conj(x, *args, **kwargs):
"""
raise NotImplementedError


@staticmethod
def sort(tensor, axis, descending = False):
"""Return a sorted copy of an array
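For context, a quick sketch of what the new any backend method provides once registered, assuming the usual tensorly pattern of exposing backend methods through the top-level namespace:

```python
import numpy as np
import tensorly as tl

mask = tl.tensor(np.array([[0.0, 1.0], [0.0, 0.0]]))
print(tl.any(mask))          # True: at least one element is non-zero
print(tl.any(mask, axis=0))  # per-column check -> [False, True]
```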
2 changes: 1 addition & 1 deletion tensorly/backend/cupy_backend.py
@@ -54,7 +54,7 @@ def sort(tensor, axis, descending = False):


for name in ['float64', 'float32', 'int64', 'int32', 'complex128', 'complex64', 'reshape', 'moveaxis',
-'transpose', 'copy', 'ones', 'zeros', 'zeros_like', 'eye',
+'transpose', 'copy', 'ones', 'zeros', 'zeros_like', 'eye', 'any',
'arange', 'where', 'dot', 'kron', 'concatenate', 'max', 'flip',
'min', 'all', 'mean', 'sum', 'prod', 'sign', 'abs', 'sqrt', 'stack',
'conj', 'diag', 'einsum', 'log2', 'tensordot']:
2 changes: 1 addition & 1 deletion tensorly/backend/jax_backend.py
@@ -78,7 +78,7 @@ def sort(tensor, axis, descending = False):
return np.sort(tensor, axis=axis)

for name in ['int64', 'int32', 'float64', 'float32', 'complex128', 'complex64', 'reshape', 'moveaxis',
-'where', 'transpose', 'arange', 'ones', 'zeros', 'flip',
+'where', 'transpose', 'arange', 'ones', 'zeros', 'flip', 'any',
'zeros_like', 'eye', 'kron', 'concatenate', 'max', 'min',
'all', 'mean', 'sum', 'prod', 'sign', 'abs', 'sqrt', 'argmin',
'argmax', 'stack', 'conj', 'diag', 'clip', 'einsum', 'log2', 'tensordot', 'sin', 'cos']:
2 changes: 1 addition & 1 deletion tensorly/backend/mxnet_backend.py
@@ -101,7 +101,7 @@ def sort(tensor, axis, descending = False):


for name in ['int64', 'int32', 'float64', 'float32', 'reshape', 'moveaxis',
-'where', 'copy', 'transpose', 'arange', 'ones', 'zeros',
+'where', 'copy', 'transpose', 'arange', 'ones', 'zeros', 'any',
'zeros_like', 'eye', 'concatenate', 'max', 'min', 'flip',
'all', 'mean', 'sum', 'prod', 'sign', 'abs', 'sqrt', 'argmin',
'argmax', 'stack', 'diag', 'einsum', 'log2', 'tensordot', 'sin', 'cos']:
2 changes: 1 addition & 1 deletion tensorly/backend/numpy_backend.py
@@ -61,7 +61,7 @@ def sort(tensor, axis, descending = False):
return np.sort(tensor, axis=axis)

for name in ['int64', 'int32', 'float64', 'float32', 'complex128', 'complex64',
-'reshape', 'moveaxis',
+'reshape', 'moveaxis', 'any',
'where', 'copy', 'transpose', 'arange', 'ones', 'zeros', 'flip',
'zeros_like', 'eye', 'kron', 'concatenate', 'max', 'min',
'all', 'mean', 'sum', 'prod', 'sign', 'abs', 'sqrt', 'argmin',
2 changes: 1 addition & 1 deletion tensorly/backend/pytorch_backend.py
@@ -187,7 +187,7 @@ def svd(matrix, full_matrices=False):

# Register the other functions
for name in ['float64', 'float32', 'int64', 'int32', 'complex128', 'complex64',
-'is_tensor', 'ones', 'zeros',
+'is_tensor', 'ones', 'zeros', 'any',
'zeros_like', 'reshape', 'eye', 'max', 'min', 'prod', 'abs',
'sqrt', 'sign', 'where', 'conj', 'diag', 'finfo', 'einsum', 'log2', 'sin', 'cos']:
PyTorchBackend.register_method(name, getattr(torch, name))
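The registration loops above all follow the same pattern: forward a list of function names wholesale to the backend library. A minimal self-contained sketch of that mechanism (MiniBackend and this simplified register_method are illustrative stand-ins, not the real tensorly Backend class):

```python
import numpy as np

class MiniBackend:
    """Toy stand-in for a tensorly-style backend class."""
    @classmethod
    def register_method(cls, name, func):
        # Attach func to the class as a static method named `name`.
        setattr(cls, name, staticmethod(func))

# Mirror the loops in the diff: bulk-register selected numpy functions.
for name in ['any', 'all', 'sum', 'prod']:
    MiniBackend.register_method(name, getattr(np, name))

print(MiniBackend.any(np.array([0, 0, 1])))  # True
```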
3 changes: 2 additions & 1 deletion tensorly/backend/tensorflow_backend.py
@@ -198,7 +198,8 @@ def log2(self,x):
(tf.einsum, 'einsum'),
(tf.tensordot, 'tensordot'),
(tfm.sin, 'sin'),
-(tfm.cos, 'cos')
+(tfm.cos, 'cos'),
+(tfm.reduce_any, 'any')
]
for source_fun, target_fun_name in _FUN_NAMES:
TensorflowBackend.register_method(target_fun_name, source_fun)
32 changes: 25 additions & 7 deletions tensorly/decomposition/_nn_cp.py
@@ -297,7 +297,7 @@ def non_negative_parafac(tensor, rank, n_iter_max=100, init='svd', svd='numpy_sv


def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='numpy_svd', tol=10e-8,
-sparsity_coefficients=None, fixed_modes=None, exact=False,
+sparsity_coefficients=None, fixed_modes=None, nn_modes='all', exact=False,
verbose=False, return_errors=False, cvg_criterion='abs_rec_error'):
"""
Non-negative CP decomposition via HALS
@@ -325,6 +325,10 @@ def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='num
fixed_modes: array of integers (between 0 and the number of modes)
Indices of modes whose factor matrices are kept fixed (not updated) during fitting.
Default: None
nn_modes: None, 'all' or array of integers (between 0 and the number of modes)
Used to specify which modes to impose non-negativity constraints on.
If 'all', then non-negativity is imposed on all modes.
Default: 'all'
exact: If True, the algorithm returns a result with high precision at a higher computational cost.
If False, it returns an approximate solution.
Default: False
@@ -370,10 +374,18 @@ def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='num
if fixed_modes is None:
fixed_modes = []

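# Normalize nn_modes: 'all' constrains every mode, None constrains none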
if nn_modes == 'all':
nn_modes = set(range(n_modes))
elif nn_modes is None:
nn_modes = set()

# Avoiding errors
for fixed_value in fixed_modes:
sparsity_coefficients[fixed_value] = None

for mode in range(n_modes):
    if sparsity_coefficients[mode] is not None and mode not in nn_modes:
        warnings.warn("Sparsity coefficient is ignored in unconstrained modes.")
# Generating the mode update sequence
modes = [mode for mode in range(n_modes) if mode not in fixed_modes]

@@ -397,11 +409,15 @@ def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='num
else:
mttkrp = unfolding_dot_khatri_rao(tensor, (None, factors), mode)

-# Call the hals resolution with nnls, optimizing the current mode
-nn_factor, _, _, _ = hals_nnls(tl.transpose(mttkrp), pseudo_inverse, tl.transpose(factors[mode]),
-                               n_iter_max=100, sparsity_coefficient=sparsity_coefficients[mode],
-                               exact=exact)
-factors[mode] = tl.transpose(nn_factor)
+if mode in nn_modes:
+    # Call the hals resolution with nnls, optimizing the current mode
+    nn_factor, _, _, _ = hals_nnls(tl.transpose(mttkrp), pseudo_inverse, tl.transpose(factors[mode]),
+                                   n_iter_max=100, sparsity_coefficient=sparsity_coefficients[mode],
+                                   exact=exact)
+    factors[mode] = tl.transpose(nn_factor)
+else:
+    factor = tl.solve(tl.transpose(pseudo_inverse), tl.transpose(mttkrp))
+    factors[mode] = tl.transpose(factor)
if tol:
factors_norm = cp_norm((weights, factors))
iprod = tl.sum(tl.sum(mttkrp * factor, axis=0) * weights)
@@ -642,10 +658,12 @@ def __init__(self, rank, n_iter_max=100, tol=1e-08,
init='svd', svd='numpy_svd',
l2_reg=0,
fixed_modes=None,
nn_modes='all',
normalize_factors=False,
sparsity=None,
exact=False,
-mask=None, svd_mask_repeats=5,
+mask=None,
+svd_mask_repeats=5,
return_errors=True,
cvg_criterion='abs_rec_error',
random_state=None,
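A short usage sketch of the new nn_modes option (mirroring the new test added to test_cp.py below; with nn_modes={0, 2}, mode 1 falls through to the plain least-squares branch added above):

```python
import numpy as np
import tensorly as tl
from tensorly.decomposition import non_negative_parafac_hals

rng = np.random.RandomState(1234)
X = tl.tensor(rng.uniform(size=(8, 9, 10)))

# Impose non-negativity on modes 0 and 2 only; mode 1 stays unconstrained.
estimate = non_negative_parafac_hals(X, rank=3, nn_modes={0, 2})
print([tl.shape(f) for f in estimate[1]])  # [(8, 3), (9, 3), (10, 3)]
```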
34 changes: 28 additions & 6 deletions tensorly/decomposition/_parafac2.py
@@ -1,8 +1,10 @@
from warnings import warn

import tensorly as tl
from ._base_decomposition import DecompositionMixin
from tensorly.random import random_parafac2
from tensorly import backend as T
-from . import parafac
+from . import parafac, non_negative_parafac_hals
from ..parafac2_tensor import parafac2_to_slice, Parafac2Tensor, _validate_parafac2_tensor
from ..cp_tensor import CPTensor
from ..base import unfold
@@ -133,7 +135,7 @@ def _parafac2_reconstruction_error(tensor_slices, decomposition):


def parafac2(tensor_slices, rank, n_iter_max=100, init='random', svd='numpy_svd', normalize_factors=False,
-tol=1e-8, random_state=None, verbose=False, return_errors=False, n_iter_parafac=5, loss_decrease_tol='relative'):
+tol=1e-8, nn_modes=None, random_state=None, verbose=False, return_errors=False, n_iter_parafac=5):
r"""PARAFAC2 decomposition [1]_ of a third order tensor via alternating least squares (ALS)
Computes a rank-`rank` PARAFAC2 decomposition of the third-order tensor defined by
@@ -156,6 +158,7 @@ def parafac2(tensor_slices, rank, n_iter_max=100, init='random', svd='numpy_svd'
where :math:`P_i` is a :math:`J_i \times R` orthogonal matrix and :math:`B` is a
:math:`R \times R` matrix.
An alternative formulation of the PARAFAC2 decomposition is that the tensor element
:math:`X_{ijk}` is given by
@@ -173,7 +176,7 @@ def parafac2(tensor_slices, rank, n_iter_max=100, init='random', svd='numpy_svd'
Either a third order tensor or a list of second order tensors that may have a different number of rows.
Note that the second mode factor matrices are allowed to change over the first mode, not the
third mode as some other implementations use (see note below).
rank : int
Number of components.
n_iter_max : int
Maximum number of iteration
@@ -231,13 +234,34 @@ def parafac2(tensor_slices, rank, n_iter_max=100, init='random', svd='numpy_svd'
[1]_, the second mode changes over the third mode. We made this change since it means
that the function accepts both lists of matrices and a single nd-array as input without
any reordering of the modes.
Because of the reformulation above, :math:`B_i = P_i B`, the :math:`B_i` matrices
cannot be constrained to be non-negative with ALS. If this mode is constrained to be
non-negative, then :math:`B` will be non-negative, but not the orthogonal :math:`P_i` matrices.
Consequently, the :math:`B_i` matrices are unlikely to be non-negative.
"""
weights, factors, projections = initialize_decomposition(tensor_slices, rank, init=init, svd=svd, random_state=random_state)

rec_errors = []
norm_tensor = tl.sqrt(sum(tl.norm(tensor_slice, 2)**2 for tensor_slice in tensor_slices))
svd_fun = _get_svd(svd)

# If nn_modes is set, we use HALS; otherwise, the standard parafac implementation.
if nn_modes is None:
def parafac_updates(X, w, f):
return parafac(X, rank, n_iter_max=n_iter_parafac,
init=(w, f), svd=svd, orthogonalise=False, verbose=verbose,
return_errors=False, normalize_factors=False, mask=None,
random_state=random_state, tol=1e-100)[1]
else:
if nn_modes == 'all' or 1 in nn_modes:
warn("Mode `1` of PARAFAC2 fitted with ALS cannot be constrained to be truly non-negative. See the documentation for more info.")
def parafac_updates(X, w, f):
return non_negative_parafac_hals(
X, rank, n_iter_max=n_iter_parafac, init=(w, f), svd=svd, nn_modes=nn_modes,
verbose=verbose, return_errors=False, tol=1e-100)[1]


projected_tensor = tl.zeros([factor.shape[0] for factor in factors], **T.context(factors[0]))

for iteration in range(n_iter_max):
@@ -248,9 +272,7 @@ def parafac2(tensor_slices, rank, n_iter_max=100, init='random', svd='numpy_svd'

projections = _compute_projections(tensor_slices, factors, svd_fun, out=projections)
projected_tensor = _project_tensor_slices(tensor_slices, projections, out=projected_tensor)
-_, factors = parafac(projected_tensor, rank, n_iter_max=n_iter_parafac, init=(weights, factors),
-                     svd=svd, orthogonalise=False, verbose=verbose, return_errors=False,
-                     normalize_factors=False, mask=None, random_state=random_state, tol=1e-100)
+factors = parafac_updates(projected_tensor, weights, factors)

if normalize_factors:
new_factors = []
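A corresponding sketch for the PARAFAC2 entry point, per the signature and warning added above. Note that requesting non-negativity on mode 1 only makes B non-negative, not the evolving B_i = P_i B, which is why the warning exists:

```python
import numpy as np
import tensorly as tl
from tensorly.decomposition import parafac2

rng = np.random.RandomState(0)
# Ragged input: slices may have different numbers of rows.
slices = [tl.tensor(rng.uniform(size=(j, 10))) for j in (15, 12, 18)]

# nn_modes routes the inner CP updates through non_negative_parafac_hals;
# here modes 0 and 2 are constrained, so no warning is raised.
decomposition = parafac2(slices, rank=3, nn_modes={0, 2}, random_state=0)
```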
32 changes: 32 additions & 0 deletions tensorly/decomposition/tests/test_cp.py
@@ -1,3 +1,4 @@
import itertools
import numpy as np
import pytest

@@ -11,6 +12,7 @@
from ...tenalg import khatri_rao
from ... import backend as T
from ...testing import assert_array_equal, assert_
from ...metrics.factors import congruence_coefficient


@pytest.mark.parametrize("linesearch", [True, False])
@@ -205,6 +207,36 @@ def test_non_negative_parafac_hals():
'norm 2 of difference between svd and random init too high')
assert_(tl.max(tl.abs(rec_svd - rec_random)) < tol_max_abs,
'abs norm of difference between svd and random init too high')


def test_non_negative_parafac_hals_one_unconstrained():
"""Test for non-negative PARAFAC HALS
TODO: more rigorous test
"""
rng = tl.check_random_state(1234)
t_shape = (8, 9, 10)
rank = 3
weights = T.tensor(rng.uniform(size=rank))
A = T.tensor(rng.uniform(size=(t_shape[0], rank)))
B = T.tensor(rng.standard_normal(size=(t_shape[1], rank)))
C = T.tensor(rng.uniform(0.1, 1.1, size=(t_shape[2], rank)))
cp_tensor = (weights, (A, B, C))
X = cp_to_tensor(cp_tensor)

nn_estimate, errs = non_negative_parafac_hals(
X, rank=3, n_iter_max=100, tol=0, init='svd', verbose=0, nn_modes={0, 2}, return_errors=True
)
X_hat = cp_to_tensor(nn_estimate)
assert_(tl.norm(X - X_hat) < 1e-3, "Error was too high")

assert_(congruence_coefficient(A, nn_estimate[1][0], absolute_value=True)[0] > 0.99, "Factor recovery not high enough")
assert_(congruence_coefficient(B, nn_estimate[1][1], absolute_value=True)[0] > 0.99, "Factor recovery not high enough")
assert_(congruence_coefficient(C, nn_estimate[1][2], absolute_value=True)[0] > 0.99, "Factor recovery not high enough")

assert_(T.all(nn_estimate[1][0] > -1e-10))
assert_(T.all(nn_estimate[1][2] > -1e-10))


@pytest.mark.xfail(tl.get_backend() == 'tensorflow', reason='Fails on tensorflow')
def test_sample_khatri_rao():
""" Test for sample_khatri_rao
