diff --git a/nmf/__init__.py b/nmf/__init__.py index e7d8f23..03bd135 100644 --- a/nmf/__init__.py +++ b/nmf/__init__.py @@ -1,19 +1,19 @@ -from ._nmf_batch_mu import NMFBatchMU -from ._nmf_batch_hals import NMFBatchHALS -from ._nmf_batch_nnls_bpp import NMFBatchNnlsBpp -from ._nmf_online_mu import NMFOnlineMU -from ._nmf_online_hals import NMFOnlineHALS -from ._nmf_online_nnls_bpp import NMFOnlineNnlsBpp +#from ._nmf_batch_mu import NMFBatchMU +#from ._nmf_batch_hals import NMFBatchHALS +#from ._nmf_batch_nnls_bpp import NMFBatchNnlsBpp +#from ._nmf_online_mu import NMFOnlineMU +#from ._nmf_online_hals import NMFOnlineHALS +#from ._nmf_online_nnls_bpp import NMFOnlineNnlsBpp -from .nmf import run_nmf +from .nmf import run_nmf, integrative_nmf -from ._inmf_batch_mu import INMFBatchMU -from ._inmf_batch_hals import INMFBatchHALS -from ._inmf_batch_nnls_bpp import INMFBatchNnlsBpp -from ._inmf_online_mu import INMFOnlineMU -from ._inmf_online_hals import INMFOnlineHALS -from ._inmf_online_nnls_bpp import INMFOnlineNnlsBpp +#from ._inmf_batch_mu import INMFBatchMU +#from ._inmf_batch_hals import INMFBatchHALS +#from ._inmf_batch_nnls_bpp import INMFBatchNnlsBpp +#from ._inmf_online_mu import INMFOnlineMU +#from ._inmf_online_hals import INMFOnlineHALS +#from ._inmf_online_nnls_bpp import INMFOnlineNnlsBpp try: from importlib.metadata import version, PackageNotFoundError diff --git a/nmf/inmf/_inmf_online.py b/nmf/inmf/_inmf_online.py deleted file mode 100644 index 71175e7..0000000 --- a/nmf/inmf/_inmf_online.py +++ /dev/null @@ -1,290 +0,0 @@ -import torch - -from ._inmf_base import INMFBase -from typing import List, Union - -class INMFOnline(INMFBase): - def __init__( - self, - n_components: int, - lam: float = 5., - init: str = 'random', - tol: float = 1e-4, - random_state: int = 0, - fp_precision: Union[str, torch.dtype] = 'float', - device_type: str = 'cpu', - max_pass: int = 10, - chunk_size: int = 2000, - w_max_iter: int = 200, - v_max_iter: int = 50, - h_max_iter: int = 50, - w_tol: float = 1e-4, - v_tol: float = 1e-4, - h_tol: float = 1e-4, - ): - super().__init__( - n_components=n_components, - lam=lam, - init=init, - tol=tol, - random_state=random_state, - fp_precision=fp_precision, - device_type=device_type, - ) - - self._max_pass = max_pass - self._chunk_size = chunk_size - self._w_max_iter = w_max_iter - self._v_max_iter = v_max_iter - self._h_max_iter = h_max_iter - self._w_tol = w_tol - self._v_tol = v_tol - self._h_tol = h_tol - - - def _h_err(self, h, hth, WVWVT, xWVT, VVT): - # Calculate L2 Loss (no sum of squares of X) for block h in trace format. - res = torch.trace((WVWVT + self._lambda * VVT) @ hth) if self._lambda > 0.0 else torch.trace(WVWVT @ hth) - res -= 2.0 * torch.trace(h.T @ xWVT) - return res - - - def _v_err(self, A, B, WV, WVWVT, VVT): - # Calculate L2 Loss (no sum of squares of X) for one batch in trace format. - res = torch.trace((WVWVT + self._lambda * VVT) @ A) if self._lambda > 0.0 else torch.trace(WVWVT @ A) - res -= 2.0 * torch.trace(B @ WV.T) - return res - - - def _w_err(self, CW, E, D): - res = torch.trace((CW + 2.0 * E) @ self.W.T) - 2.0 * torch.trace(D @ self.W.T) - return res - - - def _update_one_pass(self): - """ - A = sum hth; B = sum htx; for each batch - C = sum of hth; D = sum of htx; E = sum of AV; for all batches - """ - A = torch.zeros((self._n_components, self._n_components), dtype=self._tensor_dtype, device=self._device_type) - B = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) - C = torch.zeros((self._n_components, self._n_components), dtype=self._tensor_dtype, device=self._device_type) - D = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) - E = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) - - batch_indices = torch.randperm(self._n_batches, device=self._device_type) - for k in batch_indices: - indices = torch.randperm(self.X[k].shape[0], device=self._device_type) - - # Block-wise update - i = 0 - VVT = self.V[k] @ self.V[k].T if self._lambda > 0.0 else None - A.fill_(0.0) - B.fill_(0.0) - while i < indices.shape[0]: - idx = indices[i:(i+self._chunk_size)] - x = self.X[k][idx, :] - h = self.H[k][idx, :] - - # Update H - WV = self.W + self.V[k] - WVWVT = WV @ WV.T - hth = h.T @ h - xWVT = x @ WV.T - - h_factor_numer = xWVT - cur_h_err = self._h_err(h, hth, WVWVT, xWVT, VVT) - - for j in range(self._h_max_iter): - prev_h_err = cur_h_err - - h_factor_denom = h @ (WVWVT + self._lambda * VVT) if self._lambda > 0.0 else h @ WVWVT - self._update_matrix(h, h_factor_numer, h_factor_denom) - hth = h.T @ h - cur_h_err = self._h_err(h, hth, WVWVT, xWVT, VVT) - - if self._is_converged(prev_h_err, cur_h_err, prev_h_err, self._h_tol): - break - - self.H[k][idx, :] = h - - # Update sufficient statistics for batch k - A += hth - htx = h.T @ x - B += htx - - # Update V - V_factor_numer = B - cur_v_err = self._v_err(A, B, WV, WVWVT, VVT) - - for j in range(self._v_max_iter): - prev_v_err = cur_v_err - - V_factor_denom = A @ (WV + self._lambda * self.V[k]) - self._update_matrix(self.V[k], V_factor_numer, V_factor_denom) - WV = self.W + self.V[k] - WVWVT = WV @ WV.T - VVT = self.V[k] @ self.V[k].T if self._lambda > 0.0 else None - cur_v_err = self._v_err(A, B, WV, WVWVT, VVT) - - if self._is_converged(prev_v_err, cur_v_err, prev_v_err, self._v_tol): - break - - # Update sufficient statistics for all batches - C += hth - D += htx - CW = C @ self.W - E_new = E + A @ self.V[k] - - # Update W - W_factor_numer = D - cur_w_err = self._w_err(CW, E_new, D) - for j in range(self._w_max_iter): - prev_w_err = cur_w_err - - W_factor_denom = CW + E_new - self._update_matrix(self.W, W_factor_numer, W_factor_denom) - CW = C @ self.W - cur_w_err = self._w_err(CW, E_new, D) - - if self._is_converged(prev_w_err, cur_w_err, prev_w_err, self._w_tol): - break - - i += self._chunk_size - E = E_new - - - def _update_H_V(self): - """ - Fix W, only update V and H - A = sum hth; B = sum htx; for each batch - """ - A = torch.zeros((self._n_components, self._n_components), dtype=self._tensor_dtype, device=self._device_type) - B = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) - - for k in range(self._n_batches): - indices = torch.randperm(self.X[k].shape[0], device=self._device_type) - - # Block-wise update - i = 0 - WV = self.W + self.V[k] - WVWVT = WV @ WV.T - VVT = self.V[k] @ self.V[k].T if self._lambda > 0.0 else None - A.fill_(0.0) - B.fill_(0.0) - while i < indices.shape[0]: - idx = indices[i:(i+self._chunk_size)] - x = self.X[k][idx, :] - h = self.H[k][idx, :] - - # Update H - hth = h.T @ h - xWVT = x @ WV.T - - h_factor_numer = xWVT - cur_h_err = self._h_err(h, hth, WVWVT, xWVT, VVT) - - for j in range(self._h_max_iter): - prev_h_err = cur_h_err - - h_factor_denom = h @ (WVWVT + self._lambda * VVT) if self._lambda > 0.0 else h @ WVWVT - self._update_matrix(h, h_factor_numer, h_factor_denom) - hth = h.T @ h - cur_h_err = self._h_err(h, hth, WVWVT, xWVT, VVT) - - if self._is_converged(prev_h_err, cur_h_err, prev_h_err, self._h_tol): - break - - self.H[k][idx, :] = h - - # Update sufficient statistics for batch k - A += hth - htx = h.T @ x - B += htx - - # Update V - V_factor_numer = B - cur_v_err = self._v_err(A, B, WV, WVWVT, VVT) - - for j in range(self._v_max_iter): - prev_v_err = cur_v_err - - V_factor_denom = A @ (WV + self._lambda * self.V[k]) - self._update_matrix(self.V[k], V_factor_numer, V_factor_denom) - WV = self.W + self.V[k] - WVWVT = WV @ WV.T - VVT = self.V[k] @ self.V[k].T if self._lambda > 0.0 else None - cur_v_err = self._v_err(A, B, WV, WVWVT, VVT) - - if self._is_converged(prev_v_err, cur_v_err, prev_v_err, self._v_tol): - break - - i += self._chunk_size - - - def _update_H(self): - """ Fix W and V, update H """ - sum_h_err = 0.0 - for k in range(self._n_batches): - WV = self.W + self.V[k] - WVWVT = WV @ WV.T - VVT = self.V[k] @ self.V[k].T if self._lambda > 0.0 else None - - i = 0 - while i < self.H[k].shape[0]: - x = self.X[k][i:(i+self._chunk_size), :] - h = self.H[k][i:(i+self._chunk_size), :] - - # Update H - hth = h.T @ h - xWVT = x @ WV.T - - h_factor_numer = xWVT - cur_h_err = self._h_err(h, hth, WVWVT, xWVT, VVT) - - for j in range(self._h_max_iter): - prev_h_err = cur_h_err - - h_factor_denom = h @ (WVWVT + self._lambda * VVT) if self._lambda > 0.0 else h @ WVWVT - self._update_matrix(h, h_factor_numer, h_factor_denom) - hth = h.T @ h - cur_h_err = self._h_err(h, hth, WVWVT, xWVT, VVT) - - if self._is_converged(prev_h_err, cur_h_err, prev_h_err, self._h_tol): - break - - sum_h_err += cur_h_err - i += self._chunk_size - - return sum_h_err - - - def fit( - self, - mats: List[torch.tensor], - ): - super().fit(mats) - - for i in range(self._max_pass): - self._update_one_pass() - self._update_H_V() - H_err = self._update_H() - - self._cur_err = torch.sqrt(H_err + self._SSX) - if self._is_converged(self._prev_err, self._cur_err, self._init_err, self._tol): - self.num_iters = i + 1 - print(f" Converged after {self.num_iters} pass(es).") - return - - self._prev_err = self._cur_err - - self.num_iters = self._max_pass - print(f" Not converged after {self._max_pass} pass(es).") - - - def fit_transform( - self, - mats: List[torch.tensor], - ): - self.fit(mats) - return self.W diff --git a/nmf/inmf_models/__init__.py b/nmf/inmf_models/__init__.py new file mode 100644 index 0000000..7edd1f7 --- /dev/null +++ b/nmf/inmf_models/__init__.py @@ -0,0 +1,6 @@ +from ._inmf_batch_hals import INMFBatchHALS +from ._inmf_batch_mu import INMFBatchMU +from ._inmf_batch_nnls_bpp import INMFBatchNnlsBpp +from ._inmf_online_hals import INMFOnlineHALS +from ._inmf_online_mu import INMFOnlineMU +from ._inmf_online_nnls_bpp import INMFOnlineNnlsBpp diff --git a/nmf/_inmf_base.py b/nmf/inmf_models/_inmf_base.py similarity index 63% rename from nmf/_inmf_base.py rename to nmf/inmf_models/_inmf_base.py index 3d131aa..86cd502 100644 --- a/nmf/_inmf_base.py +++ b/nmf/inmf_models/_inmf_base.py @@ -3,19 +3,23 @@ from typing import List, Union + class INMFBase: def __init__( self, n_components: int, - lam: float = 5., - init: str = 'random', - tol: float = 1e-4, - random_state: int = 0, - fp_precision: Union[str, torch.dtype] = 'float', - device_type: str = 'cpu', + lam: float, + init: str, + tol: float, + random_state: int, + fp_precision: Union[str, torch.dtype], + device_type: str, ): self._n_components = n_components + + assert init in ['norm', 'uniform'], "Initialization method must be chosen from ['norm', 'uniform']!" self._init_method = init + self._lambda = lam self._tol = tol self._random_state = random_state @@ -32,28 +36,29 @@ def __init__( def _initialize_W_H_V(self): - ##W = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) - ##self.H = [] - ##self.V = [] - - # Random initialization - ##for k in range(self._n_batches): - ## avg = torch.sqrt(self.X[k].mean() / self._n_components) - ## H = torch.abs(avg * torch.randn((self.X[k].shape[0], self._n_components), dtype=self._tensor_dtype, device=self._device_type)) - ## V = torch.abs(0.5 * avg * torch.randn((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type)) - ## self.H.append(H) - ## self.V.append(V) - ## W += torch.abs(0.5 * avg * torch.randn((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type)) - ##W /= self._n_batches - ##self.W = W - self.W = 2.0 * torch.rand((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) + self.W = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) self.H = [] self.V = [] - for k in range(self._n_batches): - H = 2.0 * torch.rand((self.X[k].shape[0], self._n_components), dtype=self._tensor_dtype, device=self._device_type) - V = 2.0 * torch.rand((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) - self.H.append(H) - self.V.append(V) + + if self._init_method == 'norm': + self.W = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) + for k in range(self._n_batches): + avg = torch.sqrt(self.X[k].mean() / self._n_components) + H = torch.abs(avg * torch.randn((self.X[k].shape[0], self._n_components), dtype=self._tensor_dtype, device=self._device_type)) + V = torch.abs(0.5 * avg * torch.randn((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type)) + self.H.append(H) + self.V.append(V) + self.W += torch.abs(0.5 * avg * torch.randn((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type)) + self.W /= self._n_batches + else: + self.W.uniform_(0, 2) + for k in range(self._n_batches): + H = torch.zeros((self.X[k].shape[0], self._n_components), dtype=self._tensor_dtype, device=self._device_type) + H.uniform_(0, 2) + V = torch.zeros((self._n_components, self._n_features), dtype=self._tensor_dtype, device=self._device_type) + V.uniform_(0, 2) + self.H.append(H) + self.V.append(V) def _trace(self, A, B): @@ -75,7 +80,7 @@ def _cast_tensor(self, X): if not isinstance(X, torch.Tensor): if self._device_type == 'cpu' and ((self._device_type == torch.float32 and X.dtype == numpy.float32) or (self._device_type == torch.double and X.dtype == numpy.float64)): X = torch.from_numpy(X) - else: + else: X = torch.tensor(X, dtype=self._tensor_dtype, device=self._device_type) else: if self._device_type != 'cpu' and (not X.is_cuda): @@ -117,4 +122,4 @@ def fit_transform( mats: List[torch.tensor], ): self.fit(mats) - return self.W + return self.H diff --git a/nmf/_inmf_batch_base.py b/nmf/inmf_models/_inmf_batch_base.py similarity index 86% rename from nmf/_inmf_batch_base.py rename to nmf/inmf_models/_inmf_batch_base.py index b55c59d..a681608 100644 --- a/nmf/_inmf_batch_base.py +++ b/nmf/inmf_models/_inmf_batch_base.py @@ -3,17 +3,18 @@ from ._inmf_base import INMFBase from typing import List, Union + class INMFBatchBase(INMFBase): def __init__( self, n_components: int, - lam: float = 5., - init: str = 'random', - tol: float = 1e-4, - random_state: int = 0, - fp_precision: Union[str, torch.dtype] = 'float', - device_type: str = 'cpu', - max_iter: int = 200, + lam: float, + init: str, + tol: float, + random_state: int, + fp_precision: Union[str, torch.dtype], + device_type: str, + max_iter: int, ): super().__init__( n_components=n_components, @@ -24,7 +25,7 @@ def __init__( fp_precision=fp_precision, device_type=device_type, ) - + self._max_iter = max_iter diff --git a/nmf/_inmf_batch_hals.py b/nmf/inmf_models/_inmf_batch_hals.py similarity index 94% rename from nmf/_inmf_batch_hals.py rename to nmf/inmf_models/_inmf_batch_hals.py index 4a2fce6..baaddd4 100644 --- a/nmf/_inmf_batch_hals.py +++ b/nmf/inmf_models/_inmf_batch_hals.py @@ -3,19 +3,20 @@ from ._inmf_batch_base import INMFBatchBase from typing import List, Union + class INMFBatchHALS(INMFBatchBase): def __init__( self, n_components: int, - lam: float = 5., - init: str = 'random', - tol: float = 1e-4, - random_state: int = 0, - fp_precision: Union[str, torch.dtype] = 'float', - device_type: str = 'cpu', - max_iter: int = 200, - hals_tol: float = 0.0008, - hals_max_iter: int = 200, + lam: float, + init: str, + tol: float, + random_state: int, + fp_precision: Union[str, torch.dtype], + device_type: str, + max_iter: int, + hals_tol: float, + hals_max_iter: int, ): super().__init__( n_components=n_components, @@ -27,7 +28,7 @@ def __init__( device_type=device_type, max_iter=max_iter, ) - + self._zero = torch.tensor(0.0, dtype=self._tensor_dtype, device=self._device_type) self._hals_tol = hals_tol self._hals_max_iter = hals_max_iter diff --git a/nmf/_inmf_batch_mu.py b/nmf/inmf_models/_inmf_batch_mu.py similarity index 98% rename from nmf/_inmf_batch_mu.py rename to nmf/inmf_models/_inmf_batch_mu.py index 4cdd014..6d6900d 100644 --- a/nmf/_inmf_batch_mu.py +++ b/nmf/inmf_models/_inmf_batch_mu.py @@ -1,7 +1,8 @@ import torch from ._inmf_batch_base import INMFBatchBase -from typing import List, Union +from typing import List + class INMFBatchMU(INMFBatchBase): def _update_matrix(self, mat, numer, denom): diff --git a/nmf/_inmf_batch_nnls_bpp.py b/nmf/inmf_models/_inmf_batch_nnls_bpp.py similarity index 97% rename from nmf/_inmf_batch_nnls_bpp.py rename to nmf/inmf_models/_inmf_batch_nnls_bpp.py index 1fe06f3..7a1801e 100644 --- a/nmf/_inmf_batch_nnls_bpp.py +++ b/nmf/inmf_models/_inmf_batch_nnls_bpp.py @@ -1,8 +1,8 @@ import torch from ._inmf_batch_base import INMFBatchBase -from ._nnls_bpp import nnls_bpp -from typing import List, Union +from ..utils import nnls_bpp +from typing import List class INMFBatchNnlsBpp(INMFBatchBase): diff --git a/nmf/_inmf_online_base.py b/nmf/inmf_models/_inmf_online_base.py similarity index 87% rename from nmf/_inmf_online_base.py rename to nmf/inmf_models/_inmf_online_base.py index 679b71b..8ee0dcc 100644 --- a/nmf/_inmf_online_base.py +++ b/nmf/inmf_models/_inmf_online_base.py @@ -3,18 +3,19 @@ from ._inmf_base import INMFBase from typing import List, Union + class INMFOnlineBase(INMFBase): def __init__( self, n_components: int, - lam: float = 5., - init: str = 'random', - tol: float = 1e-4, - random_state: int = 0, - fp_precision: Union[str, torch.dtype] = 'float', - device_type: str = 'cpu', - max_pass: int = 20, - chunk_size: int = 5000, + lam: float, + init: str, + tol: float, + random_state: int, + fp_precision: Union[str, torch.dtype], + device_type: str, + max_pass: int, + chunk_size: int, ): super().__init__( n_components=n_components, diff --git a/nmf/_inmf_online_hals.py b/nmf/inmf_models/_inmf_online_hals.py similarity index 95% rename from nmf/_inmf_online_hals.py rename to nmf/inmf_models/_inmf_online_hals.py index 34818ae..499acee 100644 --- a/nmf/_inmf_online_hals.py +++ b/nmf/inmf_models/_inmf_online_hals.py @@ -3,22 +3,23 @@ from ._inmf_online_base import INMFOnlineBase from typing import List, Union + class INMFOnlineHALS(INMFOnlineBase): def __init__( self, n_components: int, - lam: float = 5., - init: str = 'random', - tol: float = 1e-4, - random_state: int = 0, - fp_precision: Union[str, torch.dtype] = 'float', - device_type: str = 'cpu', - max_pass: int = 20, - chunk_size: int = 5000, - chunk_max_iter: int = 200, - h_tol: float = 0.01, - v_tol: float = 0.1, - w_tol: float = 0.01, + lam: float, + init: str, + tol: float, + random_state: int, + fp_precision: Union[str, torch.dtype], + device_type: str, + max_pass: int, + chunk_size: int, + chunk_max_iter: int, + h_tol: float, + v_tol: float, + w_tol: float, ): super().__init__( n_components=n_components, @@ -78,7 +79,7 @@ def _update_one_pass(self): denom = WVWVT[l, l] + self._lambda * VVT[l, l] else: numer = xWVT[:, l] - h @ WVWVT[:, l] - denom = WVWVT[l, l] + denom = WVWVT[l, l] hvec = h[:, l] + numer / denom if torch.isnan(hvec).sum() > 0: hvec[:] = 0.0 # divide zero error: set h_new to 0 @@ -181,7 +182,7 @@ def _update_H_V(self): denom = WVWVT[l, l] + self._lambda * VVT[l, l] else: numer = xWVT[:, l] - h @ WVWVT[:, l] - denom = WVWVT[l, l] + denom = WVWVT[l, l] hvec = h[:, l] + numer / denom if torch.isnan(hvec).sum() > 0: hvec[:] = 0.0 # divide zero error: set h_new to 0 @@ -248,7 +249,7 @@ def _update_H(self): denom = WVWVT[l, l] + self._lambda * VVT[l, l] else: numer = xWVT[:, l] - h @ WVWVT[:, l] - denom = WVWVT[l, l] + denom = WVWVT[l, l] hvec = h[:, l] + numer / denom if torch.isnan(hvec).sum() > 0: hvec[:] = 0.0 # divide zero error: set h_new to 0 @@ -264,7 +265,7 @@ def _update_H(self): hth = h.T @ h sum_h_err += self._h_err(h, hth, WVWVT, xWVT, VVT) - + i += self._chunk_size return torch.sqrt(sum_h_err + self._SSX) diff --git a/nmf/_inmf_online_mu.py b/nmf/inmf_models/_inmf_online_mu.py similarity index 95% rename from nmf/_inmf_online_mu.py rename to nmf/inmf_models/_inmf_online_mu.py index 71b31fa..5efed63 100644 --- a/nmf/_inmf_online_mu.py +++ b/nmf/inmf_models/_inmf_online_mu.py @@ -3,22 +3,23 @@ from ._inmf_online_base import INMFOnlineBase from typing import List, Union + class INMFOnlineMU(INMFOnlineBase): def __init__( self, n_components: int, - lam: float = 5., - init: str = 'random', - tol: float = 1e-4, - random_state: int = 0, - fp_precision: Union[str, torch.dtype] = 'float', - device_type: str = 'cpu', - max_pass: int = 20, - chunk_size: int = 5000, - chunk_max_iter: int = 200, - h_tol: float = 0.01, - v_tol: float = 0.1, - w_tol: float = 0.01, + lam: float, + init: str, + tol: float, + random_state: int, + fp_precision: Union[str, torch.dtype], + device_type: str, + max_pass: int, + chunk_size: int, + chunk_max_iter: int, + h_tol: float, + v_tol: float, + w_tol: float, ): super().__init__( n_components=n_components, diff --git a/nmf/_inmf_online_nnls_bpp.py b/nmf/inmf_models/_inmf_online_nnls_bpp.py similarity index 99% rename from nmf/_inmf_online_nnls_bpp.py rename to nmf/inmf_models/_inmf_online_nnls_bpp.py index 28d1f98..d78a928 100644 --- a/nmf/_inmf_online_nnls_bpp.py +++ b/nmf/inmf_models/_inmf_online_nnls_bpp.py @@ -1,8 +1,9 @@ import torch from ._inmf_online_base import INMFOnlineBase -from ._nnls_bpp import nnls_bpp -from typing import List, Union +from ..utils import nnls_bpp +from typing import List + class INMFOnlineNnlsBpp(INMFOnlineBase): def _update_one_pass(self): diff --git a/nmf/nmf.py b/nmf/nmf.py index 2056289..f849ee5 100644 --- a/nmf/nmf.py +++ b/nmf/nmf.py @@ -1,9 +1,9 @@ import numpy as np import torch -from typing import Union, Tuple - -from nmf import NMFBatchMU, NMFBatchHALS, NMFBatchNnlsBpp, NMFOnlineMU, NMFOnlineHALS, NMFOnlineNnlsBpp +from typing import List, Union, Tuple, Optional +from .nmf_models import NMFBatchMU, NMFBatchHALS, NMFBatchNnlsBpp, NMFOnlineMU, NMFOnlineHALS, NMFOnlineNnlsBpp +from .inmf_models import INMFBatchHALS, INMFBatchMU, INMFBatchNnlsBpp, INMFOnlineHALS, INMFOnlineMU, INMFOnlineNnlsBpp def run_nmf( X: Union[np.array, torch.tensor], @@ -27,7 +27,7 @@ def run_nmf( online_chunk_size: int = 5000, online_chunk_max_iter: int = 200, online_h_tol: float = 0.05, - online_w_tol: float = 0.05, + online_w_tol: float = 0.05, ) -> Tuple[np.array, np.array, float]: """ Perform Non-negative Matrix Factorization (NMF). @@ -55,7 +55,7 @@ def run_nmf( :math:`||A||_{Fro}^2 = \\sum_{i, j} A_{ij}^2` (Frobenius norm) - NMF uses the multiplicative update (MU) solver, either a batch version or an online version specified in ``update_method`` parameter, to minimize the objective function. + NMF uses various solvers (specified in ``algo`` parameter), in either batch or online mode (specified in ``mode`` parameter), to minimize this objective function. Parameters ---------- @@ -76,7 +76,7 @@ def run_nmf( algo: ``str``, optional, default: ``hals`` Choose from ``mu`` (Multiplicative Update), ``hals`` (Hierarchical Alternative Least Square) and ``bpp`` (alternative non-negative least squares with Block Principal Pivoting method). mode: ``str``, optional, default: ``batch`` - Learning mode. Choose from ``batch`` and ``online``. Notice that ``online`` only works when ``beta=2.0``. For other beta loss, it switches back to ``batch`` method. + Learning mode. Choose from ``batch`` and ``online``. Notice that ``online`` only works when ``beta=2.0``. For other beta loss, it switches back to ``batch`` method. tol: ``float``, optional, default: ``1e-4`` The toleration used for convergence check. random_state: ``int``, optional, default: ``0`` @@ -102,7 +102,7 @@ def run_nmf( batch_hals_tol: ``float``, optional, default: ``0.05`` For HALS, we have the option of using HALS to mimic BPP for a possible better loss. The mimic works as follows: update H by HALS several iterations until the maximal relative change < batch_hals_tol. Then update W similarly. batch_hals_max_iter: ``int``, optional, default: ``200`` - Maximal iterations of updating H & W for mimic BPP. If this parameter set to 1, it is the standard HALS. + Maximal iterations of updating H & W for mimic BPP. If this parameter set to 1, it is the standard HALS. online_max_pass: ``int``, optional, default: ``20`` The maximum number of online passes of all data to perform. online_chunk_size: ``int``, optional, default: ``5000`` @@ -183,7 +183,7 @@ def run_nmf( beta_loss=beta_loss, tol=tol, random_state=random_state, - **kwargs + **kwargs ) H = model.fit_transform(X) @@ -191,3 +191,165 @@ def run_nmf( err = model.reconstruction_err return H.cpu().numpy(), W.cpu().numpy(), err.cpu().numpy() + + +def integrative_nmf( + X: List[Union[np.array, torch.tensor]], + n_components: int, + init: Optional[str] = None, + algo: str = "hals", + mode: str = "batch", + tol: float = 1e-4, + random_state: int = 0, + use_gpu: bool = False, + lam: float = 5., + fp_precision: Union[str, torch.dtype] = "float", + batch_max_iter: int = 200, + batch_hals_tol: float = 0.0008, + batch_hals_max_iter: int = 200, + online_max_pass: int = 20, + online_chunk_size: int = 5000, + online_chunk_max_iter: int = 200, + online_h_tol: float = 0.01, + online_v_tol: float = 0.1, + online_w_tol: float = 0.01, +) -> Tuple[List[np.array], np.array, List[np.array], float]: + """ + Run integrative Non-negative Matrix Factorization (iNMF). + + Given a list of non-negative matrices X, perform integration using NMF. + It is useful for data integration, gene program extraction in Genomics, etc. + + The objective function is + + .. math:: + + \\sum_{k}||X_k - H_k(W+V_k)||_{Fro}^2 + \\lambda * \\sum_{k}||H_kV_k||_1 + + where + + :math:`||vec(A)||_1 = \\sum_{i, j} abs(A_{ij})` (Element-wise L1 norm) + + :math:`||A||_{Fro}^2 = \\sum_{i, j} A_{ij}^2` (Frobenius norm) + + iNMF uses various solvers (specified in ``algo`` parameter), either in batch or online mode (specified in ``mode`` parameter), to minimize this objective function. + + Parameters + ---------- + + X: List of ``numpy.array`` or ``torch.tensor`` + The input list of non-negative matrices of shape (n_samples_i, n_features), one per batch. The n_samples_i is number of samples in batch i, and all batches must have the same number of features. + n_components: ``int`` + Number of components achieved after iNMF. + init: ``str``, optional, default: ``None`` + Method for initialization on H, W, and V matrices. Available options are: ``norm``, ``uniform``, meaning using random numbers generated from Normal or Uniform distribution. + If ``None``, use ``norm`` for online mode, while ``uniform`` for batch mode, in order to achieve best performance. + algo: ``str``, optional, default: ``hals`` + Choose from ``mu`` (Multiplicative Update), ``hals`` (Hierarchical Alternative Least Square) and ``bpp`` (alternative non-negative least squares with Block Principal Pivoting method). + mode: ``str``, optional, default: ``batch`` + Learning mode. Choose from ``batch`` and ``online``. + tol: ``float``, optional, default: ``1e-4`` + The toleration used for convergence check. + random_state: ``int``, optional, default: ``0`` + The random state used for reproducibility on the results. + use_gpu: ``bool``, optional, default: ``False`` + If ``True``, use GPU if available. Otherwise, use CPU only. + lam: ``float``, optional, default: ``5.0`` + The coefficient for regularization terms. If ``0``, then no regularization will be performed. + fp_precision: ``str``, optional, default: ``float`` + The numeric precision on the results. + If ``float``, set precision to ``torch.float``; if ``double``, set precision to ``torch.double``. + Alternatively, choose Pytorch's `torch dtype `_ of your own. + batch_max_iter: ``int``, optional, default: ``200`` + The maximum number of iterations to perform for batch learning. + batch_hals_tol: ``float``, optional, default: ``0.0008`` + For HALS, we have the option of using HALS to mimic BPP for a possible better loss. The mimic works as follows: update H by HALS several iterations until the maximal relative change < batch_hals_tol. Then update W similarly. + batch_hals_max_iter: ``int``, optional, default: ``200`` + Maximal iterations of updating H & W for mimic BPP. If this parameter set to 1, it is the standard HALS. + online_max_pass: ``int``, optional, default: ``20`` + The maximum number of online passes of all data to perform. + online_chunk_size: ``int``, optional, default: ``5000`` + The chunk / mini-batch size for online learning. + online_chunk_max_iter: ``int``, optional, default: ``200`` + The maximum number of iterations for updating H or W in online learning. + online_h_tol: ``float``, optional, default: ``0.01`` + The tolerance for updating H in each chunk in online learning. + online_v_tol: ``float``, optional, default: ``0.1`` + The tolerance for updating V in each chunk in online learning. + online_w_tol: ``float``, optional, default: ``0.01`` + The tolerance for updating W in each chunk in online learning. + + Returns + ------- + H: List of ``numpy.array`` + List of the resulting decomposed matrices of shape (n_samples_i, n_components), where n_samples_i is the number of samples in batch i. + Each matrix represents the transformed coordinates of samples regarding components of the corresponding batch. + W: ``numpy.array`` + The resulting decomposed matrix of shape (n_components, n_features), which represents the shared information across the given batches in terms of features. + V: List of ``numpy.array`` + List of the resulting decomposed matrices of shape (n_components, n_features). + Each matrix represents the batch-specific information in terms of features of the corresponding batch. + reconstruction_error: ``float`` + The L2 Loss between the origin matrices X and their approximation after iNMF. + + Examples + -------- + >>> H, W, V, err = integrative_nmf(X, n_components=20) + >>> H, W, V, err = integrative_nmf(X, n_components=20, algo='bpp', mode='online') + """ + + device_type = 'cpu' + if use_gpu: + if torch.cuda.is_available(): + device_type = 'cuda' + print("Use GPU mode.") + else: + print("CUDA is not available on your machine. Use CPU mode instead.") + + if algo not in {'hals', 'mu', 'bpp'}: + raise ValueError("Parameter algo must be a valid value from ['hals', 'mu', 'bpp']!") + if mode not in {'batch', 'online'}: + raise ValueError("Parameter mode must be a valid value from ['batch', 'online']!") + + if init is None: + init = 'norm' if mode == 'online' else 'uniform' + + model_class = None + kwargs = {'device_type': device_type, 'lam': lam, 'fp_precision': fp_precision} + + if mode == 'batch': + kwargs['max_iter'] = batch_max_iter + if algo == 'hals': + model_class = INMFBatchHALS + kwargs['hals_tol'] = batch_hals_tol + kwargs['hals_max_iter'] = batch_hals_max_iter + elif algo == 'bpp': + model_class = INMFBatchNnlsBpp + else: + model_class = INMFBatchMU + else: + kwargs['max_pass'] = online_max_pass + kwargs['chunk_size'] = online_chunk_size + if algo == 'bpp': + model_class = INMFOnlineNnlsBpp + else: + model_class = INMFOnlineHALS if algo == 'hals' else INMFOnlineMU + kwargs['chunk_max_iter'] = online_chunk_max_iter + kwargs['h_tol'] = online_h_tol + kwargs['v_tol'] = online_v_tol + kwargs['w_tol'] = online_w_tol + + model = model_class( + n_components=n_components, + init=init, + tol=tol, + random_state=random_state, + **kwargs + ) + + H = model.fit_transform(X) + W = model.W + V = model.V + err = model.reconstruction_err + + return H, W, V, err diff --git a/nmf/nmf_models/__init__.py b/nmf/nmf_models/__init__.py new file mode 100644 index 0000000..efc2374 --- /dev/null +++ b/nmf/nmf_models/__init__.py @@ -0,0 +1,6 @@ +from ._nmf_batch_mu import NMFBatchMU +from ._nmf_batch_hals import NMFBatchHALS +from ._nmf_batch_nnls_bpp import NMFBatchNnlsBpp +from ._nmf_online_mu import NMFOnlineMU +from ._nmf_online_hals import NMFOnlineHALS +from ._nmf_online_nnls_bpp import NMFOnlineNnlsBpp diff --git a/nmf/_nmf_base.py b/nmf/nmf_models/_nmf_base.py similarity index 99% rename from nmf/_nmf_base.py rename to nmf/nmf_models/_nmf_base.py index 7d015ad..ecb6eaf 100644 --- a/nmf/_nmf_base.py +++ b/nmf/nmf_models/_nmf_base.py @@ -3,6 +3,7 @@ from typing import Union + class NMFBase: def __init__( self, @@ -137,7 +138,7 @@ def _cast_tensor(self, X): if not isinstance(X, torch.Tensor): if self._device_type == 'cpu' and ((self._device_type == torch.float32 and X.dtype == numpy.float32) or (self._device_type == torch.double and X.dtype == numpy.float64)): X = torch.from_numpy(X) - else: + else: X = torch.tensor(X, dtype=self._tensor_dtype, device=self._device_type) else: if self._device_type != 'cpu' and (not X.is_cuda): diff --git a/nmf/_nmf_batch_base.py b/nmf/nmf_models/_nmf_batch_base.py similarity index 99% rename from nmf/_nmf_batch_base.py rename to nmf/nmf_models/_nmf_batch_base.py index 2fe3167..62344b2 100644 --- a/nmf/_nmf_batch_base.py +++ b/nmf/nmf_models/_nmf_batch_base.py @@ -3,6 +3,7 @@ from ._nmf_base import NMFBase from typing import Union + class NMFBatchBase(NMFBase): def __init__( self, diff --git a/nmf/_nmf_batch_hals.py b/nmf/nmf_models/_nmf_batch_hals.py similarity index 100% rename from nmf/_nmf_batch_hals.py rename to nmf/nmf_models/_nmf_batch_hals.py diff --git a/nmf/_nmf_batch_mu.py b/nmf/nmf_models/_nmf_batch_mu.py similarity index 98% rename from nmf/_nmf_batch_mu.py rename to nmf/nmf_models/_nmf_batch_mu.py index 5e9eba4..d6b9471 100644 --- a/nmf/_nmf_batch_mu.py +++ b/nmf/nmf_models/_nmf_batch_mu.py @@ -1,6 +1,3 @@ -import torch - -from typing import Union from ._nmf_batch_base import NMFBatchBase diff --git a/nmf/_nmf_batch_nnls_bpp.py b/nmf/nmf_models/_nmf_batch_nnls_bpp.py similarity index 99% rename from nmf/_nmf_batch_nnls_bpp.py rename to nmf/nmf_models/_nmf_batch_nnls_bpp.py index 3939b59..1f6dcb9 100644 --- a/nmf/_nmf_batch_nnls_bpp.py +++ b/nmf/nmf_models/_nmf_batch_nnls_bpp.py @@ -1,7 +1,7 @@ import torch from ._nmf_batch_base import NMFBatchBase -from ._nnls_bpp import nnls_bpp +from ..utils import nnls_bpp from typing import Union diff --git a/nmf/_nmf_online_base.py b/nmf/nmf_models/_nmf_online_base.py similarity index 99% rename from nmf/_nmf_online_base.py rename to nmf/nmf_models/_nmf_online_base.py index 3cb6a2d..b023893 100644 --- a/nmf/_nmf_online_base.py +++ b/nmf/nmf_models/_nmf_online_base.py @@ -3,6 +3,7 @@ from ._nmf_base import NMFBase from typing import Union + class NMFOnlineBase(NMFBase): def __init__( self, @@ -21,7 +22,7 @@ def __init__( chunk_size: int = 5000, ): assert beta_loss == 2.0 # only work for F norm for now - + super().__init__( n_components=n_components, init=init, @@ -54,7 +55,7 @@ def _loss(self): """ calculate loss online by passing through all data""" i = 0 WWT = self.W @ self.W.T - + sum_h_err = torch.tensor(0.0, dtype=torch.double, device=self._device_type) # make sure sum_h_err is double to avoid summation errors while i < self.H.shape[0]: x = self.X[i:(i+self._chunk_size), :] diff --git a/nmf/_nmf_online_hals.py b/nmf/nmf_models/_nmf_online_hals.py similarity index 100% rename from nmf/_nmf_online_hals.py rename to nmf/nmf_models/_nmf_online_hals.py diff --git a/nmf/_nmf_online_mu.py b/nmf/nmf_models/_nmf_online_mu.py similarity index 100% rename from nmf/_nmf_online_mu.py rename to nmf/nmf_models/_nmf_online_mu.py diff --git a/nmf/_nmf_online_nnls_bpp.py b/nmf/nmf_models/_nmf_online_nnls_bpp.py similarity index 99% rename from nmf/_nmf_online_nnls_bpp.py rename to nmf/nmf_models/_nmf_online_nnls_bpp.py index 4bb52b7..ee9712c 100644 --- a/nmf/_nmf_online_nnls_bpp.py +++ b/nmf/nmf_models/_nmf_online_nnls_bpp.py @@ -1,7 +1,7 @@ import torch from ._nmf_online_base import NMFOnlineBase -from ._nnls_bpp import nnls_bpp +from ..utils import nnls_bpp from typing import Union diff --git a/nmf/utils/__init__.py b/nmf/utils/__init__.py new file mode 100644 index 0000000..a8c8bce --- /dev/null +++ b/nmf/utils/__init__.py @@ -0,0 +1 @@ +from ._nnls_bpp import nnls_bpp diff --git a/nmf/_nnls_bpp.py b/nmf/utils/_nnls_bpp.py similarity index 100% rename from nmf/_nnls_bpp.py rename to nmf/utils/_nnls_bpp.py diff --git a/tests/inmf_benchmark.py b/tests/inmf_benchmark.py new file mode 100644 index 0000000..0e788d5 --- /dev/null +++ b/tests/inmf_benchmark.py @@ -0,0 +1,69 @@ +import time +import torch + +#import pegasus as pg +import numpy as np +import pandas as pd + +from nmf import integrative_nmf + +def loss(X, H, W, V, lam): + res = 0.0 + for k in range(len(X)): + res += torch.norm(X[k].double() - H[k].double() @ (W.double() + V[k].double()), p=2)**2 + if lam > 0: + res += lam * torch.norm(H[k].double() @ V[k].double(), p=2)**2 + + return torch.sqrt(res) + +def run_test(mats, algo, mode, n_components, lam, seed, fp_precision, batch_max_iter): + print(f"{algo} {mode} Experiment...") + + torch.set_num_threads(12) + + ts_start = time.time() + H, W, V, err = integrative_nmf(mats, algo=algo, mode=mode, n_components=n_components, lam=lam, random_state=seed, fp_precision=fp_precision, batch_max_iter=batch_max_iter) + ts_end = time.time() + err_confirm = loss(mats, H, W, V, lam) + print(f"{algo} {mode} finishes in {ts_end - ts_start} s, with error {err} (confirmed with {err_confirm}).") + +#data = pg.read_input("MantonBM_nonmix.zarr.zip") +#pg.qc_metrics(data, min_genes=500, max_genes=6000, mito_prefix='MT-', percent_mito=10) +#pg.filter_data(data) +#pg.identify_robust_genes(data) +#pg.log_norm(data) +#pg.highly_variable_features(data, consider_batch=True) +#keyword = pg.select_features(data, features='highly_variable_features', standardize=True, max_value=10) +#X = (data.uns[keyword] + data.uns['stdzn_mean'] / data.uns['stdzn_std']).astype(np.float32) +#X[X < 0] = 0.0 +#np.save("inmf_data/counts.npy", X) +#data.obs[['Channel']].to_csv("inmf_data/metadata.csv", index=False, header=False) +X = np.load("inmf_data/subset/counts.npy") + +df = pd.read_csv("inmf_data/subset/metadata.csv", header=None) +df[0] = df[0].astype('category') +mats = [] +#for chan in data.obs['Channel'].cat.categories: +# x = X[df_obs.loc[df_obs['Channel']==chan].index, :].copy() +# mats.append(torch.tensor(x)) +for chan in df[0].cat.categories: + x = X[df.loc[df[0]==chan].index, :].copy() + print(x.shape) + mats.append(torch.tensor(x, dtype=torch.float)) + +print("Start iNMF...") +#rnd_seeds = [28728712, 39074257, 751935947, 700933753, 1315698701, 1096583738, 1381716902, 1862944882, 472642840, 530691960] +rnd_seeds = [0] +#rnd_seeds = [3365, 2217, 629, 715, 4289, 3849, 625, 6598, 8275, 9570] + +cnt = 0 +lam = 5.0 +algo_list = ['hals', 'bpp', 'mu'] +mode_list = ['batch', 'online'] +for seed in rnd_seeds: + cnt += 1 + print(f"{cnt}. Experiment with random seed {seed}...") + + for algo in algo_list: + for mode in mode_list: + run_test(mats, algo, mode, n_components=20, lam=lam, seed=seed, fp_precision='float', batch_max_iter=500) diff --git a/test.py b/tests/nmf_benchmark.py similarity index 100% rename from test.py rename to tests/nmf_benchmark.py