Skip to content

Commit

Permalink
Updated
Browse files Browse the repository at this point in the history
  • Loading branch information
Bo Li committed Jun 20, 2021
1 parent 3b58739 commit 68b3bf7
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 0 deletions.
4 changes: 4 additions & 0 deletions nmf/_nmf_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int,
):
self.k = n_components
self._beta = beta_loss
Expand Down Expand Up @@ -50,6 +51,9 @@ def __init__(
self._tol = tol
self._random_state = random_state

if n_jobs > 0:
torch.set_num_threads(n_jobs)


def _get_regularization_loss(self, mat, l1_reg, l2_reg):
res = 0.0
Expand Down
2 changes: 2 additions & 0 deletions nmf/_nmf_batch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int = -1,
max_iter: int = 500,
):
super().__init__(
Expand All @@ -31,6 +32,7 @@ def __init__(
l1_ratio_H=l1_ratio_H,
fp_precision=fp_precision,
device_type=device_type,
n_jobs=n_jobs,
)

self._max_iter = max_iter
Expand Down
2 changes: 2 additions & 0 deletions nmf/_nmf_batch_hals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int = -1,
max_iter: int = 500,
hals_tol: float = 0.05,
hals_max_iter: int = 200,
Expand All @@ -36,6 +37,7 @@ def __init__(
l1_ratio_H=l1_ratio_H,
fp_precision=fp_precision,
device_type=device_type,
n_jobs=n_jobs,
max_iter=max_iter,
)

Expand Down
2 changes: 2 additions & 0 deletions nmf/_nmf_batch_nnls_bpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int = -1,
max_iter: int = 500,
):
assert beta_loss == 2.0 # only work for F norm for now
Expand All @@ -35,6 +36,7 @@ def __init__(
l1_ratio_H=l1_ratio_H,
fp_precision=fp_precision,
device_type=device_type,
n_jobs=n_jobs,
max_iter=max_iter,
)

Expand Down
2 changes: 2 additions & 0 deletions nmf/_nmf_online_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int = -1,
max_pass: int = 20,
chunk_size: int = 5000,
):
Expand All @@ -34,6 +35,7 @@ def __init__(
l1_ratio_H=l1_ratio_H,
fp_precision=fp_precision,
device_type=device_type,
n_jobs=n_jobs,
)

self._max_pass = max_pass
Expand Down
2 changes: 2 additions & 0 deletions nmf/_nmf_online_hals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int = -1,
max_pass: int = 20,
chunk_size: int = 5000,
chunk_max_iter: int = 200,
Expand All @@ -36,6 +37,7 @@ def __init__(
l1_ratio_H=l1_ratio_H,
fp_precision=fp_precision,
device_type=device_type,
n_jobs=n_jobs,
max_pass=max_pass,
chunk_size=chunk_size,
)
Expand Down
2 changes: 2 additions & 0 deletions nmf/_nmf_online_mu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int = -1,
max_pass: int = 20,
chunk_size: int = 5000,
chunk_max_iter: int = 200,
Expand All @@ -36,6 +37,7 @@ def __init__(
l1_ratio_H=l1_ratio_H,
fp_precision=fp_precision,
device_type=device_type,
n_jobs=n_jobs,
max_pass=max_pass,
chunk_size=chunk_size,
)
Expand Down
2 changes: 2 additions & 0 deletions nmf/_nmf_online_nnls_bpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
l1_ratio_H: float,
fp_precision: Union[str, torch.dtype],
device_type: str,
n_jobs: int = -1,
max_pass: int = 20,
chunk_size: int = 5000,
):
Expand All @@ -34,6 +35,7 @@ def __init__(
l1_ratio_H=l1_ratio_H,
fp_precision=fp_precision,
device_type=device_type,
n_jobs=n_jobs,
max_pass=max_pass,
chunk_size=chunk_size,
)
Expand Down
3 changes: 3 additions & 0 deletions nmf/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def run_nmf(
algo: str = "hals",
mode: str = "batch",
tol: float = 1e-4,
n_jobs: int = -1,
random_state: int = 0,
use_gpu: bool = False,
alpha_W: float = 0.0,
Expand Down Expand Up @@ -79,6 +80,8 @@ def run_nmf(
Learning mode. Choose from ``batch`` and ``online``. Notice that ``online`` only works when ``beta=2.0``. For other beta loss, it switches back to ``batch`` method.
tol: ``float``, optional, default: ``1e-4``
The tolerance used for the convergence check.
n_jobs: ``int``, optional, default: ``-1``
Number of CPU threads to use. If ``-1``, use PyTorch's default setting.
random_state: ``int``, optional, default: ``0``
The random state used for reproducibility on the results.
use_gpu: ``bool``, optional, default: ``False``
Expand Down

0 comments on commit 68b3bf7

Please sign in to comment.