Skip to content

Commit

Permalink
fix tensor_to_vec
Browse files Browse the repository at this point in the history
  • Loading branch information
caglayantuna committed May 12, 2021
1 parent d38084d commit d001f75
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tensorly/tenalg/tests/test_proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..proximal import svd_thresholding, soft_thresholding, hals_nnls, fista, active_set_nnls
from ..proximal import procrustes
from ...testing import assert_array_equal, assert_array_almost_equal
from tensorly import tensor_to_vec

# Author: Jean Kossaifi

Expand Down Expand Up @@ -76,7 +77,7 @@ def test_hals_nnls():
ata = T.dot(T.transpose(a), a)
xinit = T.zeros(T.shape(atb))
x_hals = hals_nnls(atb, ata, V=xinit, exact=True)[0]
assert_array_almost_equal(true_res, x_hals, decimal=3)
assert_array_almost_equal(true_res, x_hals, decimal=2)


def test_fista():
Expand All @@ -97,6 +98,6 @@ def test_active_set_nnls():
b = T.dot(a, true_res)
atb = T.dot(T.transpose(a), b)
ata = T.dot(T.transpose(a), a)
x_as = active_set_nnls(T.tensor_to_vec(atb), ata)
x_as = active_set_nnls(tensor_to_vec(atb), ata)
x_as = T.reshape(x_as, T.shape(atb))
assert_array_almost_equal(true_res, x_as, decimal=3)

0 comments on commit d001f75

Please sign in to comment.