Merge pull request #3 from lilab-bcb/develop
Finish up first version

Showing 36 changed files with 3,486 additions and 565 deletions.
@@ -1,4 +1,8 @@
*~
*.npy
__pycache__/
ext_modules/*.c
ext_modules/*.cpp
nmf/cylib/*.so
__pycache__/
build/
@@ -0,0 +1,6 @@
exclude tests/*
exclude archive/*
exclude docs/**
exclude build_linux.sh
exclude wheel_build/*
exclude .*
@@ -1,3 +1,5 @@
# nmf-torch
==============
NMF-Torch
==============

A PyTorch implementation of Non-negative Matrix Factorization.
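
For readers new to the method, the factorization the README refers to is the standard NMF model; this is the textbook objective, not text quoted from the repository:

\min_{H \ge 0,\; W \ge 0} \; \lVert X - H W \rVert_F^2

where X is the non-negative input matrix, H holds the per-sample loadings, and W the learned components.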
@@ -0,0 +1,116 @@
import torch

from ._inmf_base import INMFBase
from typing import List, Union


class INMFBatchHALSWrong(INMFBase):
    def __init__(
        self,
        n_components: int,
        lam: float = 5.,
        init: str = 'random',
        tol: float = 1e-4,
        random_state: int = 0,
        fp_precision: Union[str, torch.dtype] = 'float',
        device_type: str = 'cpu',
        max_iter: int = 200,
    ):
        super().__init__(
            n_components=n_components,
            lam=lam,
            init=init,
            tol=tol,
            random_state=random_state,
            fp_precision=fp_precision,
            device_type=device_type,
        )

        self._max_iter = max_iter
        self._zero = torch.tensor(0.0, dtype=self._tensor_dtype, device=self._device_type)


    def _update_H_V_W(self):
        W_numer = torch.zeros_like(self.W)
        W_denom = torch.zeros_like(self._HTH[0])
        # Update Hs and Vs and calculate terms for updating W
        for k in range(self._n_batches):
            # Update H[k]
            for l in range(self._n_components):
                numer = self._XWVT[k][:, l] - self.H[k] @ self._WVWVT[k][:, l]
                if self._lambda > 0.0:
                    denom = self._WVWVT[k][l, l] + self._lambda * self._VVT[k][l, l]
                    h_new = self.H[k][:, l] * (self._WVWVT[k][l, l] / denom) + numer / denom
                else:
                    h_new = self.H[k][:, l] + numer / self._WVWVT[k][l, l]
                if torch.isnan(h_new).sum() > 0:
                    h_new[:] = 0.0  # divide zero error: set h_new to 0
                else:
                    h_new = h_new.maximum(self._zero)
                self.H[k][:, l] = h_new
            # Cache HTH
            self._HTH[k] = self.H[k].T @ self.H[k]

            # Update V[k]
            HTX = self.H[k].T @ self.X[k]
            for l in range(self._n_components):
                numer = HTX[l, :] - self._HTH[k][l, :] @ (self.W + self.V[k])
                denom = 1.0 + self._lambda
                v_new = self.V[k][l, :] * (1.0 / denom) + numer / (denom * self._HTH[k][l, l])
                if torch.isnan(v_new).sum() > 0:
                    v_new[:] = 0.0  # divide zero error: set v_new to 0
                else:
                    v_new = v_new.maximum(self._zero)
                self.V[k][l, :] = v_new
            # Cache VVT
            if self._lambda > 0.0:
                self._VVT[k] = self.V[k] @ self.V[k].T

            # Update W numerator and denominator terms
            W_numer += (HTX - self._HTH[k] @ self.V[k])
            W_denom += self._HTH[k]

        # Update W
        for l in range(self._n_components):
            w_new = self.W[l, :] + (W_numer[l, :] - W_denom[l, :] @ self.W) / W_denom[l, l]
            if torch.isnan(w_new).sum() > 0:
                w_new[:] = 0.0  # divide zero error: set w_new to 0
            else:
                w_new = w_new.maximum(self._zero)
            self.W[l, :] = w_new
        # Cache WVWVT and XWVT
        for k in range(self._n_batches):
            WV = self.W + self.V[k]
            self._WVWVT[k] = WV @ WV.T
            self._XWVT[k] = self.X[k] @ WV.T


    def fit(
        self,
        mats: List[torch.tensor],
    ):
        super().fit(mats)

        # Batch update
        for i in range(self._max_iter):
            self._update_H_V_W()

            if (i + 1) % 10 == 0:
                self._cur_err = self._loss()
                print(f"    niter={i+1}, loss={self._cur_err}.")
                if self._is_converged(self._prev_err, self._cur_err, self._init_err):
                    self.num_iters = i + 1
                    print(f"    Converged after {self.num_iters} iteration(s).")
                    return

                self._prev_err = self._cur_err

        self.num_iters = self._max_iter
        print(f"    Not converged after {self.num_iters} iteration(s).")


    def fit_transform(
        self,
        mats: List[torch.tensor],
    ):
        self.fit(mats)
        return self.W
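
The cached quantities above (_WVWVT, _VVT, and the lambda-scaled terms) are consistent with HALS-style coordinate updates for a LIGER-style integrative NMF objective. As a hedged reading of the code, not a statement taken from this diff, the target appears to be

\min_{H_k, V_k, W \ge 0} \; \sum_k \lVert X_k - H_k (W + V_k) \rVert_F^2 + \lambda \sum_k \lVert H_k V_k \rVert_F^2

with a shared factor W, batch-specific factors V_k, and per-batch loadings H_k. The class name ends in "Wrong", so these updates are presumably retained as a known-incorrect baseline for comparison rather than as the recommended solver.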
@@ -0,0 +1,195 @@
# cython: language_level=3
# distutils: language = c++

### NNLS Block principal pivoting: Kim and Park (2011).

import numpy as np
import torch

cimport cython

from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp.unordered_map cimport unordered_map
from cython.operator import dereference as deref
from cython.operator import postincrement as pinc

ctypedef unsigned char uint8

ctypedef fused array_type:
    float
    double


cdef inline array_type _filter(array_type number, array_type tol):
    if number > -tol and number < tol:
        return 0.0
    return number


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef _nnls_bpp(array_type[:, :] CTC, array_type[:, :] CTB, array_type[:, :] X, str device_type):
    # CTC = C.T @ C, CTB = C.T @ B, X.shape = (q, r)
    numpy_type = np.float32 if array_type is float else np.float64
    torch_type = torch.float if array_type is float else torch.double
    cdef array_type tol = 1e-6 if array_type is float else 1e-12

    cdef Py_ssize_t i, j, col_idx, row_idx
    cdef Py_ssize_t n_iter, fvsize, gvsize, uqsize, size_I, pos

    cdef int q = CTB.shape[0]
    cdef int r = CTB.shape[1]
    cdef int max_iter = 5 * q

    cdef int backup_cap = 3  # maximum backup tries
    cdef int[:] alpha = np.full(r, backup_cap, dtype=np.int32)  # cap on backup rule
    cdef int[:] beta = np.full(r, q+1, dtype=np.int32)  # number of infeasible variables

    ### Initialization, setting G = {1, ..., q}
    cdef array_type[:, :] Y = np.zeros_like(CTB, dtype=numpy_type)
    cdef uint8[:, :] V = np.zeros((q, r), dtype=np.bool_)
    cdef int[:] Vsize = np.zeros((r,), dtype=np.int32)

    cdef vector[int] I

    cdef uint8[:, :] F = np.zeros((q, r), dtype=np.bool_)  # y_F = 0, G = ~F, x_G = 0

    CTC_L_tensor = torch.zeros((q, q), dtype=torch_type, device='cpu')
    cdef array_type[:, :] CTC_L = CTC_L_tensor.numpy()
    CTB_L_tensor = torch.zeros((q, r), dtype=torch_type, device='cpu')
    cdef array_type[:, :] CTB_L = CTB_L_tensor.numpy()

    cdef array_type[:, :] x
    cdef array_type[:, :] y

    cdef unordered_map[string, vector[int]] uniq_F
    cdef unordered_map[string, vector[int]].iterator it

    cdef string Fvec_str
    Fvec_str.resize(q, b' ')

    cdef vector[int] Fvec
    cdef vector[int] Gvec


    I.clear()
    for j in range(r):
        for i in range(q):
            X[i, j] = 0.0
            Y[i, j] = -CTB[i, j]
            V[i, j] = Y[i, j] < 0
            Vsize[j] += V[i, j]

        if Vsize[j] > 0:
            I.push_back(j)

    n_iter = 0
    while I.size() > 0 and n_iter < max_iter:
        # Split indices in I into 3 cases:
        for col_idx in I:
            if Vsize[col_idx] < beta[col_idx]:
                # Case 1: Apply full exchange rule
                alpha[col_idx] = backup_cap
                beta[col_idx] = Vsize[col_idx]
                for i in range(q):
                    F[i, col_idx] ^= V[i, col_idx]
            elif alpha[col_idx] > 0:
                # Case 2: Retry with full exchange rule
                alpha[col_idx] -= 1
                for i in range(q):
                    F[i, col_idx] ^= V[i, col_idx]
            else:
                # Case 3: Apply backup rule
                row_idx = 0
                for i in range(q-1, -1, -1):
                    if V[i, col_idx] > 0:
                        row_idx = i
                        break
                F[row_idx, col_idx] ^= 1

        # Get unique F columns with indices mapping back to F.
        uniq_F.clear()
        for col_idx in I:
            for i in range(q):
                Fvec_str[i] = b'1' if F[i, col_idx] else b'0'

            it = uniq_F.find(Fvec_str)
            if it != uniq_F.end():
                deref(it).second.push_back(col_idx)
            else:
                uniq_F[Fvec_str] = vector[int](1, col_idx)

        # Solve grouped normal equations
        it = uniq_F.begin()
        while it != uniq_F.end():
            Fvec.clear()
            Gvec.clear()
            for i in range(q):
                if deref(it).first[i] == b'1':
                    Fvec.push_back(i)
                else:
                    Gvec.push_back(i)

            fvsize = Fvec.size()
            gvsize = Gvec.size()
            uqsize = deref(it).second.size()

            if fvsize > 0:
                # CTC_L = CTC[Fvec, Fvec]
                for i in range(fvsize):
                    for j in range(fvsize):
                        CTC_L[i, j] = CTC[Fvec[i], Fvec[j]]
                L_tensor = torch.cholesky(CTC_L_tensor[0:fvsize, 0:fvsize]) if device_type == 'cpu' else torch.cholesky(CTC_L_tensor[0:fvsize, 0:fvsize].cuda())
                # CTB_L = CTB[Fvec, Ii]
                for i in range(fvsize):
                    for j in range(uqsize):
                        CTB_L[i, j] = CTB[Fvec[i], deref(it).second[j]]
                x_tensor = torch.cholesky_solve(CTB_L_tensor[0:fvsize, 0:uqsize], L_tensor) if device_type == 'cpu' else torch.cholesky_solve(CTB_L_tensor[0:fvsize, 0:uqsize].cuda(), L_tensor)
                x = x_tensor.cpu().numpy()
                # clean up
                for i in range(fvsize):
                    for j in range(uqsize):
                        X[Fvec[i], deref(it).second[j]] = _filter(x[i, j], tol)
                        Y[Fvec[i], deref(it).second[j]] = 0.0

            if gvsize > 0:
                if fvsize > 0:
                    # CTC_L = CTC[~Fvec, Fvec]
                    for i in range(gvsize):
                        for j in range(fvsize):
                            CTC_L[i, j] = CTC[Gvec[i], Fvec[j]]
                    y_tensor = (CTC_L_tensor[0:gvsize, 0:fvsize] @ x_tensor) if device_type == 'cpu' else (CTC_L_tensor[0:gvsize, 0:fvsize].cuda() @ x_tensor)
                    y = y_tensor.cpu().numpy()

                    for i in range(gvsize):
                        row_idx = Gvec[i]
                        for j in range(uqsize):
                            col_idx = deref(it).second[j]
                            Y[row_idx, col_idx] = _filter(y[i,j] - CTB[row_idx, col_idx], tol)
                            X[row_idx, col_idx] = 0.0
                else:
                    for i in range(gvsize):
                        row_idx = Gvec[i]
                        for j in range(uqsize):
                            col_idx = deref(it).second[j]
                            Y[row_idx, col_idx] = _filter(-CTB[row_idx, col_idx], tol)
                            X[row_idx, col_idx] = 0.0

            pinc(it)

        size_I = 0
        for j in range(I.size()):
            pos = I[j]
            Vsize[pos] = 0
            for i in range(q):
                V[i, pos] = ((X[i, pos] < 0.0) & F[i, pos]) | ((Y[i, pos] < 0.0) & (F[i, pos] == 0))
                Vsize[pos] += V[i, pos]
            if Vsize[pos] > 0:
                I[size_I] = pos
                size_I += 1
        I.resize(size_I)

        n_iter += 1

    return n_iter if I.size() == 0 else -1
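
To make the calling convention concrete, here is a minimal, hypothetical driver for _nnls_bpp, reconstructed only from the signature and the `CTC = C.T @ C, CTB = C.T @ B` comment above; the import path is an assumption, and X is filled in place:

import numpy as np
from nmf.cylib.nnls_bpp_utils import _nnls_bpp  # assumed module path for the compiled extension

rng = np.random.default_rng(0)
C = rng.random((30, 8), dtype=np.float32)   # design matrix, shape (m, q)
B = rng.random((30, 5), dtype=np.float32)   # right-hand sides, shape (m, r)
CTC = np.ascontiguousarray(C.T @ C)         # (q, q) Gram matrix
CTB = np.ascontiguousarray(C.T @ B)         # (q, r)
X = np.zeros((8, 5), dtype=np.float32)      # solution of min ||C @ X - B||_F s.t. X >= 0, written in place

n_iter = _nnls_bpp(CTC, CTB, X, 'cpu')      # iteration count on success, -1 if not converged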
@@ -1 +1,27 @@
from .nmf import NMF

#from ._nmf_batch_mu import NMFBatchMU
#from ._nmf_batch_hals import NMFBatchHALS
#from ._nmf_batch_nnls_bpp import NMFBatchNnlsBpp
#from ._nmf_online_mu import NMFOnlineMU
#from ._nmf_online_hals import NMFOnlineHALS
#from ._nmf_online_nnls_bpp import NMFOnlineNnlsBpp

from .nmf import run_nmf, integrative_nmf

#from ._inmf_batch_mu import INMFBatchMU
#from ._inmf_batch_hals import INMFBatchHALS
#from ._inmf_batch_nnls_bpp import INMFBatchNnlsBpp
#from ._inmf_online_mu import INMFOnlineMU
#from ._inmf_online_hals import INMFOnlineHALS
#from ._inmf_online_nnls_bpp import INMFOnlineNnlsBpp

try:
    from importlib.metadata import version, PackageNotFoundError
except ImportError:  # < Python 3.8: use backport module
    from importlib_metadata import version, PackageNotFoundError

try:
    __version__ = version('nmf-torch')
    del version
except PackageNotFoundError:
    pass
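
A short, hedged sketch of how the names re-exported by this __init__.py are reached from user code; only names visible in this diff are used, and the signatures of run_nmf and integrative_nmf are not shown here, so no arguments are given:

import nmf
from nmf import NMF, run_nmf, integrative_nmf  # public entry points re-exported above

print(nmf.__version__)  # populated via importlib.metadata when the package is installed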
@@ -0,0 +1 @@
This folder contains compiled Cython libraries.
@@ -0,0 +1,6 @@
from ._inmf_batch_hals import INMFBatchHALS
from ._inmf_batch_mu import INMFBatchMU
from ._inmf_batch_nnls_bpp import INMFBatchNnlsBpp
from ._inmf_online_hals import INMFOnlineHALS
from ._inmf_online_mu import INMFOnlineMU
from ._inmf_online_nnls_bpp import INMFOnlineNnlsBpp
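
The six imports enumerate a solver grid: {batch, online} algorithms crossed with {MU, HALS, NNLS-BPP} update rules. A purely illustrative dispatch table one might build on top of this subpackage (not part of the repository; the package path is assumed):

from nmf.inmf_models import (  # assumed package path for this __init__.py
    INMFBatchMU, INMFBatchHALS, INMFBatchNnlsBpp,
    INMFOnlineMU, INMFOnlineHALS, INMFOnlineNnlsBpp,
)

INMF_SOLVERS = {
    ('batch', 'mu'): INMFBatchMU,
    ('batch', 'hals'): INMFBatchHALS,
    ('batch', 'nnls_bpp'): INMFBatchNnlsBpp,
    ('online', 'mu'): INMFOnlineMU,
    ('online', 'hals'): INMFOnlineHALS,
    ('online', 'nnls_bpp'): INMFOnlineNnlsBpp,
}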