Commit
Merge pull request #3 from lilab-bcb/develop
Finish up first version
yihming authored Jun 20, 2021
2 parents b908570 + 7b09d27 commit 532dc1c
Showing 36 changed files with 3,486 additions and 565 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -1,4 +1,8 @@
 *~
 *.pdf
 *.npy
-__pycache__/
+ext_modules/*.c
+ext_modules/*.cpp
+nmf/cylib/*.so
+__pycache__/
+build/
6 changes: 6 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,6 @@
exclude tests/*
exclude archive/*
exclude docs/**
exclude build_linux.sh
exclude wheel_build/*
exclude .*
4 changes: 3 additions & 1 deletion README.md → README.rst
@@ -1,3 +1,5 @@
-# nmf-torch
+==============
+NMF-Torch
+==============
 
 A PyTorch implementation of Non-negative Matrix Factorization.
116 changes: 116 additions & 0 deletions archive/_inmf_batch_hals_wrong.py
@@ -0,0 +1,116 @@
import torch

from ._inmf_base import INMFBase
from typing import List, Union

class INMFBatchHALSWrong(INMFBase):
    def __init__(
        self,
        n_components: int,
        lam: float = 5.,
        init: str = 'random',
        tol: float = 1e-4,
        random_state: int = 0,
        fp_precision: Union[str, torch.dtype] = 'float',
        device_type: str = 'cpu',
        max_iter: int = 200,
    ):
        super().__init__(
            n_components=n_components,
            lam=lam,
            init=init,
            tol=tol,
            random_state=random_state,
            fp_precision=fp_precision,
            device_type=device_type,
        )

        self._max_iter = max_iter
        self._zero = torch.tensor(0.0, dtype=self._tensor_dtype, device=self._device_type)


    def _update_H_V_W(self):
        W_numer = torch.zeros_like(self.W)
        W_denom = torch.zeros_like(self._HTH[0])
        # Update Hs and Vs, and accumulate the terms for updating W
        for k in range(self._n_batches):
            # Update H[k], one component at a time
            for l in range(self._n_components):
                numer = self._XWVT[k][:, l] - self.H[k] @ self._WVWVT[k][:, l]
                if self._lambda > 0.0:
                    denom = self._WVWVT[k][l, l] + self._lambda * self._VVT[k][l, l]
                    h_new = self.H[k][:, l] * (self._WVWVT[k][l, l] / denom) + numer / denom
                else:
                    h_new = self.H[k][:, l] + numer / self._WVWVT[k][l, l]
                if torch.isnan(h_new).sum() > 0:
                    h_new[:] = 0.0  # divide-by-zero error: set h_new to 0
                else:
                    h_new = h_new.maximum(self._zero)
                self.H[k][:, l] = h_new
            # Cache HTH
            self._HTH[k] = self.H[k].T @ self.H[k]

            # Update V[k], one component at a time
            HTX = self.H[k].T @ self.X[k]
            for l in range(self._n_components):
                numer = HTX[l, :] - self._HTH[k][l, :] @ (self.W + self.V[k])
                denom = 1.0 + self._lambda
                v_new = self.V[k][l, :] * (1.0 / denom) + numer / (denom * self._HTH[k][l, l])
                if torch.isnan(v_new).sum() > 0:
                    v_new[:] = 0.0  # divide-by-zero error: set v_new to 0
                else:
                    v_new = v_new.maximum(self._zero)
                self.V[k][l, :] = v_new
            # Cache VVT
            if self._lambda > 0.0:
                self._VVT[k] = self.V[k] @ self.V[k].T

            # Accumulate W numerator and denominator
            W_numer += (HTX - self._HTH[k] @ self.V[k])
            W_denom += self._HTH[k]

        # Update W
        for l in range(self._n_components):
            w_new = self.W[l, :] + (W_numer[l, :] - W_denom[l, :] @ self.W) / W_denom[l, l]
            if torch.isnan(w_new).sum() > 0:
                w_new[:] = 0.0  # divide-by-zero error: set w_new to 0
            else:
                w_new = w_new.maximum(self._zero)
            self.W[l, :] = w_new
        # Cache WVWVT and XWVT
        for k in range(self._n_batches):
            WV = self.W + self.V[k]
            self._WVWVT[k] = WV @ WV.T
            self._XWVT[k] = self.X[k] @ WV.T


    def fit(
        self,
        mats: List[torch.Tensor],
    ):
        super().fit(mats)

        # Batch update
        for i in range(self._max_iter):
            self._update_H_V_W()

            if (i + 1) % 10 == 0:
                self._cur_err = self._loss()
                print(f"  niter={i+1}, loss={self._cur_err}.")
                if self._is_converged(self._prev_err, self._cur_err, self._init_err):
                    self.num_iters = i + 1
                    print(f"  Converged after {self.num_iters} iteration(s).")
                    return

                self._prev_err = self._cur_err

        self.num_iters = self._max_iter
        print(f"  Not converged after {self.num_iters} iteration(s).")


    def fit_transform(
        self,
        mats: List[torch.Tensor],
    ):
        self.fit(mats)
        return self.W
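
For orientation: the updates above appear to target the integrative NMF objective sum_k ||X_k - H_k (W + V_k)||_F^2 + lam * sum_k ||H_k V_k||_F^2, with HALS-style per-component updates for each H_k and V_k followed by the shared W. Below is a minimal, illustrative usage sketch. It is an assumption, not part of the commit: the import path is hypothetical (the file lives under archive/ and uses a relative import of INMFBase, so it is not importable as-is), the data are random, and the class itself is archived as a known-incorrect variant.

import torch

# Hypothetical import path; see the caveats above.
from archive._inmf_batch_hals_wrong import INMFBatchHALSWrong

torch.manual_seed(0)
# Two assumed batches sharing 500 features (e.g., cells x genes).
mats = [torch.rand(300, 500), torch.rand(200, 500)]

model = INMFBatchHALSWrong(n_components=20, lam=5.0, max_iter=200)
W = model.fit_transform(mats)  # shared factor matrix W, shape (n_components, n_features)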
195 changes: 195 additions & 0 deletions ext_modules/nnls_bpp_utils.pyx
@@ -0,0 +1,195 @@
# cython: language_level=3
# distutils: language = c++

### NNLS via block principal pivoting: Kim and Park, 2011.

import numpy as np
import torch

cimport cython

from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp.unordered_map cimport unordered_map
from cython.operator import dereference as deref
from cython.operator import postincrement as pinc

ctypedef unsigned char uint8

ctypedef fused array_type:
    float
    double


cdef inline array_type _filter(array_type number, array_type tol):
    if number > -tol and number < tol:
        return 0.0
    return number


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef _nnls_bpp(array_type[:, :] CTC, array_type[:, :] CTB, array_type[:, :] X, str device_type):
    # CTC = C.T @ C, CTB = C.T @ B, X.shape = (q, r)
    numpy_type = np.float32 if array_type is float else np.float64
    torch_type = torch.float if array_type is float else torch.double
    cdef array_type tol = 1e-6 if array_type is float else 1e-12

    cdef Py_ssize_t i, j, col_idx, row_idx
    cdef Py_ssize_t n_iter, fvsize, gvsize, uqsize, size_I, pos

    cdef int q = CTB.shape[0]
    cdef int r = CTB.shape[1]
    cdef int max_iter = 5 * q

    cdef int backup_cap = 3  # maximum number of backup tries
    cdef int[:] alpha = np.full(r, backup_cap, dtype=np.int32)  # cap on the backup rule
    cdef int[:] beta = np.full(r, q+1, dtype=np.int32)  # number of infeasible variables

    ### Initialization: start with all q indices in the passive set G
    cdef array_type[:, :] Y = np.zeros_like(CTB, dtype=numpy_type)
    cdef uint8[:, :] V = np.zeros((q, r), dtype=np.bool_)
    cdef int[:] Vsize = np.zeros((r,), dtype=np.int32)

    cdef vector[int] I

    cdef uint8[:, :] F = np.zeros((q, r), dtype=np.bool_)  # y_F = 0, G = ~F, x_G = 0

    CTC_L_tensor = torch.zeros((q, q), dtype=torch_type, device='cpu')
    cdef array_type[:, :] CTC_L = CTC_L_tensor.numpy()
    CTB_L_tensor = torch.zeros((q, r), dtype=torch_type, device='cpu')
    cdef array_type[:, :] CTB_L = CTB_L_tensor.numpy()

    cdef array_type[:, :] x
    cdef array_type[:, :] y

    cdef unordered_map[string, vector[int]] uniq_F
    cdef unordered_map[string, vector[int]].iterator it

    cdef string Fvec_str
    Fvec_str.resize(q, b' ')

    cdef vector[int] Fvec
    cdef vector[int] Gvec


    I.clear()
    for j in range(r):
        for i in range(q):
            X[i, j] = 0.0
            Y[i, j] = -CTB[i, j]
            V[i, j] = Y[i, j] < 0
            Vsize[j] += V[i, j]

        if Vsize[j] > 0:
            I.push_back(j)

    n_iter = 0
    while I.size() > 0 and n_iter < max_iter:
        # Split indices in I into 3 cases:
        for col_idx in I:
            if Vsize[col_idx] < beta[col_idx]:
                # Case 1: Apply full exchange rule
                alpha[col_idx] = backup_cap
                beta[col_idx] = Vsize[col_idx]
                for i in range(q):
                    F[i, col_idx] ^= V[i, col_idx]
            elif alpha[col_idx] > 0:
                # Case 2: Retry with full exchange rule
                alpha[col_idx] -= 1
                for i in range(q):
                    F[i, col_idx] ^= V[i, col_idx]
            else:
                # Case 3: Apply backup rule
                row_idx = 0
                for i in range(q-1, -1, -1):
                    if V[i, col_idx] > 0:
                        row_idx = i
                        break
                F[row_idx, col_idx] ^= 1

        # Get unique F columns with indices mapping back to F.
        uniq_F.clear()
        for col_idx in I:
            for i in range(q):
                Fvec_str[i] = b'1' if F[i, col_idx] else b'0'

            it = uniq_F.find(Fvec_str)
            if it != uniq_F.end():
                deref(it).second.push_back(col_idx)
            else:
                uniq_F[Fvec_str] = vector[int](1, col_idx)

        # Solve grouped normal equations
        it = uniq_F.begin()
        while it != uniq_F.end():
            Fvec.clear()
            Gvec.clear()
            for i in range(q):
                if deref(it).first[i] == b'1':
                    Fvec.push_back(i)
                else:
                    Gvec.push_back(i)

            fvsize = Fvec.size()
            gvsize = Gvec.size()
            uqsize = deref(it).second.size()

            if fvsize > 0:
                # CTC_L = CTC[Fvec, Fvec]
                for i in range(fvsize):
                    for j in range(fvsize):
                        CTC_L[i, j] = CTC[Fvec[i], Fvec[j]]
                L_tensor = torch.cholesky(CTC_L_tensor[0:fvsize, 0:fvsize]) if device_type == 'cpu' else torch.cholesky(CTC_L_tensor[0:fvsize, 0:fvsize].cuda())
                # CTB_L = CTB[Fvec, Ii]
                for i in range(fvsize):
                    for j in range(uqsize):
                        CTB_L[i, j] = CTB[Fvec[i], deref(it).second[j]]
                x_tensor = torch.cholesky_solve(CTB_L_tensor[0:fvsize, 0:uqsize], L_tensor) if device_type == 'cpu' else torch.cholesky_solve(CTB_L_tensor[0:fvsize, 0:uqsize].cuda(), L_tensor)
                x = x_tensor.cpu().numpy()
                # Clean up small values (divide-by-zero guards)
                for i in range(fvsize):
                    for j in range(uqsize):
                        X[Fvec[i], deref(it).second[j]] = _filter(x[i, j], tol)
                        Y[Fvec[i], deref(it).second[j]] = 0.0

            if gvsize > 0:
                if fvsize > 0:
                    # CTC_L = CTC[~Fvec, Fvec]
                    for i in range(gvsize):
                        for j in range(fvsize):
                            CTC_L[i, j] = CTC[Gvec[i], Fvec[j]]
                    y_tensor = (CTC_L_tensor[0:gvsize, 0:fvsize] @ x_tensor) if device_type == 'cpu' else (CTC_L_tensor[0:gvsize, 0:fvsize].cuda() @ x_tensor)
                    y = y_tensor.cpu().numpy()

                    for i in range(gvsize):
                        row_idx = Gvec[i]
                        for j in range(uqsize):
                            col_idx = deref(it).second[j]
                            Y[row_idx, col_idx] = _filter(y[i, j] - CTB[row_idx, col_idx], tol)
                            X[row_idx, col_idx] = 0.0
                else:
                    for i in range(gvsize):
                        row_idx = Gvec[i]
                        for j in range(uqsize):
                            col_idx = deref(it).second[j]
                            Y[row_idx, col_idx] = _filter(-CTB[row_idx, col_idx], tol)
                            X[row_idx, col_idx] = 0.0

            pinc(it)

        size_I = 0
        for j in range(I.size()):
            pos = I[j]
            Vsize[pos] = 0
            for i in range(q):
                V[i, pos] = ((X[i, pos] < 0.0) & F[i, pos]) | ((Y[i, pos] < 0.0) & (F[i, pos] == 0))
                Vsize[pos] += V[i, pos]
            if Vsize[pos] > 0:
                I[size_I] = pos
                size_I += 1
        I.resize(size_I)

        n_iter += 1

    return n_iter if I.size() == 0 else -1
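
In effect, for each column b of B the routine solves min_x ||C x - b||_2 subject to x >= 0, given the precomputed Gram products CTC = C.T @ C and CTB = C.T @ B; it fills X in place and returns the number of outer iterations, or -1 if some column is still infeasible after 5*q iterations. A hedged NumPy cross-check against SciPy's reference NNLS solver follows; the compiled import path nmf.cylib.nnls_bpp_utils is an assumption inferred from the nmf/cylib/*.so entry in .gitignore, not something this diff states.

import numpy as np
from scipy.optimize import nnls

rng = np.random.default_rng(0)
C = rng.random((50, 10))  # (m, q) design matrix
B = rng.random((50, 4))   # (m, r) right-hand sides

CTC = C.T @ C                        # (q, q) Gram matrix, as _nnls_bpp expects
CTB = np.ascontiguousarray(C.T @ B)  # (q, r)
X = np.zeros_like(CTB)               # output buffer, filled in place

# Assumed import path for the compiled extension:
# from nmf.cylib.nnls_bpp_utils import _nnls_bpp
# n_iter = _nnls_bpp(CTC, CTB, X, 'cpu')  # -1 signals non-convergence

# Reference solution with SciPy, one column at a time; X should match X_ref
# up to the small-value filtering done by _filter.
X_ref = np.column_stack([nnls(C, B[:, j])[0] for j in range(B.shape[1])])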
28 changes: 27 additions & 1 deletion nmf/__init__.py
@@ -1 +1,27 @@
-from .nmf import NMF
+
+#from ._nmf_batch_mu import NMFBatchMU
+#from ._nmf_batch_hals import NMFBatchHALS
+#from ._nmf_batch_nnls_bpp import NMFBatchNnlsBpp
+#from ._nmf_online_mu import NMFOnlineMU
+#from ._nmf_online_hals import NMFOnlineHALS
+#from ._nmf_online_nnls_bpp import NMFOnlineNnlsBpp
+
+from .nmf import run_nmf, integrative_nmf
+
+#from ._inmf_batch_mu import INMFBatchMU
+#from ._inmf_batch_hals import INMFBatchHALS
+#from ._inmf_batch_nnls_bpp import INMFBatchNnlsBpp
+#from ._inmf_online_mu import INMFOnlineMU
+#from ._inmf_online_hals import INMFOnlineHALS
+#from ._inmf_online_nnls_bpp import INMFOnlineNnlsBpp
+
+try:
+    from importlib.metadata import version, PackageNotFoundError
+except ImportError:  # < Python 3.8: Use backport module
+    from importlib_metadata import version, PackageNotFoundError
+
+try:
+    __version__ = version('nmf-torch')
+    del version
+except PackageNotFoundError:
+    pass
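
After this change the package's public surface is run_nmf and integrative_nmf, with __version__ resolved from installed package metadata (and silently absent when nmf-torch is not installed). A minimal sanity check of that surface; the function signatures are not shown in this diff, so only the imports and the version lookup are exercised:

import nmf

print(nmf.__version__)                # set via importlib metadata for 'nmf-torch'
print(callable(nmf.run_nmf))          # True
print(callable(nmf.integrative_nmf))  # True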
1 change: 1 addition & 0 deletions nmf/cylib/README.rst
@@ -0,0 +1 @@
This folder contains compiled Cython libraries.
6 changes: 6 additions & 0 deletions nmf/inmf_models/__init__.py
@@ -0,0 +1,6 @@
from ._inmf_batch_hals import INMFBatchHALS
from ._inmf_batch_mu import INMFBatchMU
from ._inmf_batch_nnls_bpp import INMFBatchNnlsBpp
from ._inmf_online_hals import INMFOnlineHALS
from ._inmf_online_mu import INMFOnlineMU
from ._inmf_online_nnls_bpp import INMFOnlineNnlsBpp
