Skip to content

Commit

Permalink
change convergence conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
ddbourgin committed May 10, 2020
1 parent fe5065d commit c70e14c
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions numpy_ml/factorization/factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,18 @@ def fit(self, X, W=None, H=None, n_initializations=10, verbose=False):

def _fit(self, X, W, H, verbose):
self._init_factor_matrices(X, W, H)
prev_loss = loss = np.inf
W, H = self.W, self.H

for i in range(self.max_iter):
prev_loss = loss
W = self._update_factor(X, H.T)
H = self._update_factor(X.T, W).T

loss = self._loss(X, W @ H)

if verbose:
print("[Iter {}] Loss: {:.6f}".format(i + 1, loss))
print("[Iter {}] Loss: {:.8f}".format(i + 1, loss))

if (prev_loss - loss) <= self.tol:
if loss <= self.tol:
break

return W, H, loss
Expand Down Expand Up @@ -252,8 +250,7 @@ def _loss(self, X, Xhat):

def _update_H(self, X, W, H):
"""Perform the fast HALS update for H"""
# eps = np.finfo(float).eps
eps = 1e-16
eps = np.finfo(float).eps
XtW = X.T @ W # dim: (M, K)
WtW = W.T @ W # dim: (K, K)

Expand All @@ -264,7 +261,7 @@ def _update_H(self, X, W, H):

def _update_W(self, X, W, H):
"""Perform the fast HALS update for W"""
eps = 1e-16 # np.finfo(float).eps
eps = np.finfo(float).eps
XHt = X @ H.T # dim: (N, K)
HHt = H @ H.T # dim: (K, K)

Expand Down Expand Up @@ -360,16 +357,14 @@ def _fit(self, X, W, H, verbose):
self._init_factor_matrices(X, W, H)

W, H = self.W, self.H
prev_loss = loss = np.inf
for i in range(self.max_iter):
prev_loss = loss
H = self._update_H(X, W, H)
W = self._update_W(X, W, H)
loss = self._loss(X, W @ H)

if verbose:
print("[Iter {}] Loss: {:.4f}".format(i + 1, loss))
print("[Iter {}] Loss: {:.8f}".format(i + 1, loss))

if (prev_loss - loss) <= self.tol:
if loss <= self.tol:
break
return W, H, loss

0 comments on commit c70e14c

Please sign in to comment.