Commit: Added nmf batch/online NNLS BPP

Bo Li committed Jun 6, 2021
1 parent b7fbef3 commit 39039fa
Showing 8 changed files with 342 additions and 11 deletions.
2 changes: 2 additions & 0 deletions nmf/__init__.py
@@ -2,8 +2,10 @@
 
 from ._nmf_batch_mu import NMFBatchMU
 from ._nmf_batch_hals import NMFBatchHALS
+from ._nmf_batch_nnls_bpp import NMFBatchNnlsBpp
 from ._nmf_online_mu import NMFOnlineMU
 from ._nmf_online_hals import NMFOnlineHALS
+from ._nmf_online_nnls_bpp import NMFOnlineNnlsBpp
 
 from ._inmf_batch_mu import INMFBatchMU
 from ._inmf_batch_hals import INMFBatchHALS
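For context, a minimal usage sketch of the newly exported batch solver. This is illustrative only: the random data, init="random", and fp_precision="float" are assumptions rather than values taken from this commit; the constructor signature follows _nmf_batch_nnls_bpp.py below.

    import torch
    from nmf import NMFBatchNnlsBpp

    X = torch.rand(100, 50)  # any nonnegative (n_samples, n_features) matrix

    model = NMFBatchNnlsBpp(
        n_components=10,
        init="random",         # assumed init mode
        beta_loss=2.0,         # required: NNLS-BPP asserts beta_loss == 2.0 (Frobenius norm)
        tol=1e-4,
        random_state=0,
        alpha_W=0.0,
        l1_ratio_W=0.0,
        alpha_H=0.0,
        l1_ratio_H=0.0,
        fp_precision="float",  # assumed precision flag
        device_type="cpu",
    )
    model.fit(X)  # per the updates below, factors land in model.H and model.W with X ~= model.H @ model.W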
1 change: 1 addition & 0 deletions nmf/_inmf_batch_nnls_bpp.py
@@ -4,6 +4,7 @@
 from ._nnls_bpp import nnls_bpp
 from typing import List, Union
 
+
 class INMFBatchNnlsBpp(INMFBatchBase):
     def _update_H_V_W(self):
         W_numer = torch.zeros_like(self.W)
4 changes: 2 additions & 2 deletions nmf/_nmf_base.py
@@ -53,9 +53,9 @@ def __init__(
 
     def _get_regularization_loss(self, mat, l1_reg, l2_reg):
         res = 0.0
-        if l1_reg > 0:
+        if l1_reg > 0.0:
             res += l1_reg * mat.norm(p=1)
-        if l2_reg > 0:
+        if l2_reg > 0.0:
             res += l2_reg * mat.norm(p=2)**2 / 2
         return res
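As a sanity check of the penalty this method computes, a standalone re-implementation (not library code) that can be verified against explicit sums: l1_reg * ||M||_1 + l2_reg / 2 * ||M||_F^2.

    import torch

    def reg_loss(mat, l1_reg, l2_reg):
        res = 0.0
        if l1_reg > 0.0:
            res += l1_reg * mat.norm(p=1)            # sum of absolute values
        if l2_reg > 0.0:
            res += l2_reg * mat.norm(p=2) ** 2 / 2   # half the squared Frobenius norm
        return res

    M = torch.tensor([[1.0, -2.0], [0.5, 0.0]])
    assert torch.isclose(reg_loss(M, 0.1, 0.0), 0.1 * M.abs().sum())
    assert torch.isclose(reg_loss(M, 0.0, 0.1), 0.1 * (M * M).sum() / 2)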
106 changes: 106 additions & 0 deletions nmf/_nmf_batch_nnls_bpp.py
@@ -0,0 +1,106 @@
+import torch
+
+from ._nmf_batch_base import NMFBatchBase
+from ._nnls_bpp import nnls_bpp
+from typing import Union
+
+
+class NMFBatchNnlsBpp(NMFBatchBase):
+    def __init__(
+        self,
+        n_components: int,
+        init,
+        beta_loss: float,
+        tol: float,
+        random_state: int,
+        alpha_W: float,
+        l1_ratio_W: float,
+        alpha_H: float,
+        l1_ratio_H: float,
+        fp_precision: Union[str, torch.dtype],
+        device_type: str,
+        max_iter: int = 500,
+    ):
+        assert beta_loss == 2.0  # only works for the Frobenius norm for now
+
+        super().__init__(
+            n_components=n_components,
+            init=init,
+            beta_loss=beta_loss,
+            tol=tol,
+            random_state=random_state,
+            alpha_W=alpha_W,
+            l1_ratio_W=l1_ratio_W,
+            alpha_H=alpha_H,
+            l1_ratio_H=l1_ratio_H,
+            fp_precision=fp_precision,
+            device_type=device_type,
+            max_iter=max_iter,
+        )
+
+        if self._l2_reg_H > 0.0:
+            self._l2_H_I = torch.eye(self.k, dtype=self._tensor_dtype, device=self._device_type) * self._l2_reg_H
+        if self._l2_reg_W > 0.0:
+            self._l2_W_I = torch.eye(self.k, dtype=self._tensor_dtype, device=self._device_type) * self._l2_reg_W
+
+
+    def _get_regularization_loss(self, mat, l1_reg, l2_reg):
+        res = 0.0
+        if l1_reg > 0:
+            dim = 0 if mat.shape[0] == self.k else 1
+            res += l1_reg * mat.norm(p=1, dim=dim).norm(p=2)**2
+        if l2_reg > 0:
+            res += l2_reg * mat.norm(p=2)**2 / 2
+        return res
+
+
+    def _update_H(self):
+        if self._l1_reg_H == 0.0 and self._l2_reg_H == 0.0:
+            n_iter = nnls_bpp(self._WWT, self._XWT.T, self.H.T, self._device_type)
+        else:
+            CTC = self._WWT.clone()
+            if self._l1_reg_H > 0.0:
+                CTC += 2.0 * self._l1_reg_H
+            if self._l2_reg_H > 0.0:
+                CTC += self._l2_H_I
+            n_iter = nnls_bpp(CTC, self._XWT.T, self.H.T, self._device_type)
+        # print(f"H n_iter={n_iter}.")
+        self._HTH = self.H.T @ self.H
+
+
+    def _update_W(self):
+        HTX = self.H.T @ self.X
+        if self._l1_reg_W == 0.0 and self._l2_reg_W == 0.0:
+            n_iter = nnls_bpp(self._HTH, HTX, self.W, self._device_type)
+        else:
+            CTC = self._HTH.clone()
+            if self._l1_reg_W > 0.0:
+                CTC += 2.0 * self._l1_reg_W
+            if self._l2_reg_W > 0.0:
+                CTC += self._l2_W_I
+            n_iter = nnls_bpp(CTC, HTX, self.W, self._device_type)
+        # print(f"W n_iter={n_iter}.")
+        self._WWT = self.W @ self.W.T
+        self._XWT = self.X @ self.W.T
+
+
+    def fit(self, X):
+        super().fit(X)
+
+        # Batch update.
+        for i in range(self._max_iter):
+            self._update_H()
+            self._update_W()
+
+            if (i + 1) % 10 == 0:
+                self._cur_err = self._loss()
+                print(f"    niter={i+1}, loss={self._cur_err}.")
+                if self._is_converged(self._prev_err, self._cur_err, self._init_err):
+                    self.num_iters = i + 1
+                    print(f"    Converged after {self.num_iters} iteration(s).")
+                    return
+
+                self._prev_err = self._cur_err
+
+        self.num_iters = self._max_iter
+        print(f"    Not converged after {self.num_iters} iteration(s).")
8 changes: 2 additions & 6 deletions nmf/_nmf_online_base.py
@@ -19,10 +19,9 @@ def __init__(
         device_type: str,
         max_pass: int = 20,
         chunk_size: int = 5000,
-        chunk_max_iter: int = 200,
-        h_tol: float = 0.05,
-        w_tol: float = 0.05,
     ):
+        assert beta_loss == 2.0  # only works for the Frobenius norm for now
+
         super().__init__(
             n_components=n_components,
             init=init,
@@ -39,9 +38,6 @@
 
         self._max_pass = max_pass
         self._chunk_size = chunk_size
-        self._chunk_max_iter = chunk_max_iter
-        self._h_tol = h_tol
-        self._w_tol = w_tol
 
 
     def _h_err(self, h, hth, WWT, xWT):
6 changes: 3 additions & 3 deletions nmf/_nmf_online_hals.py
@@ -38,11 +38,11 @@ def __init__(
             device_type=device_type,
             max_pass=max_pass,
             chunk_size=chunk_size,
-            chunk_max_iter=chunk_max_iter,
-            h_tol=h_tol,
-            w_tol=w_tol,
         )
 
+        self._chunk_max_iter = chunk_max_iter
+        self._h_tol = h_tol
+        self._w_tol = w_tol
         self._zero = torch.tensor(0.0, dtype=self._tensor_dtype, device=self._device_type)
Expand Down
40 changes: 40 additions & 0 deletions nmf/_nmf_online_mu.py
@@ -5,6 +5,46 @@
 
 
 class NMFOnlineMU(NMFOnlineBase):
+    def __init__(
+        self,
+        n_components: int,
+        init,
+        beta_loss: float,
+        tol: float,
+        random_state: int,
+        alpha_W: float,
+        l1_ratio_W: float,
+        alpha_H: float,
+        l1_ratio_H: float,
+        fp_precision: Union[str, torch.dtype],
+        device_type: str,
+        max_pass: int = 20,
+        chunk_size: int = 5000,
+        chunk_max_iter: int = 200,
+        h_tol: float = 0.05,
+        w_tol: float = 0.05,
+    ):
+        super().__init__(
+            n_components=n_components,
+            init=init,
+            beta_loss=beta_loss,
+            tol=tol,
+            random_state=random_state,
+            alpha_W=alpha_W,
+            l1_ratio_W=l1_ratio_W,
+            alpha_H=alpha_H,
+            l1_ratio_H=l1_ratio_H,
+            fp_precision=fp_precision,
+            device_type=device_type,
+            max_pass=max_pass,
+            chunk_size=chunk_size,
+        )
+
+        self._chunk_max_iter = chunk_max_iter
+        self._h_tol = h_tol
+        self._w_tol = w_tol
+
+
     def _update_matrix(self, mat, numer, denom):
         rates = numer / denom
         rates[denom < self._epsilon] = 0.0
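The _update_matrix shown at the end of this hunk begins a multiplicative-update step; the rest of its body is not shown above. As a generic illustration of the pattern only, not a reconstruction of the truncated code: each entry is scaled by numer/denom, and entries with vanishing denominators are zeroed rather than divided.

    import torch

    EPSILON = 1e-20  # illustrative guard value; the library keeps its own self._epsilon

    def mu_step(mat, numer, denom):
        rates = numer / denom
        rates[denom < EPSILON] = 0.0  # zero out entries whose denominator vanished
        return mat * rates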