-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadmm_serial.py
29 lines (24 loc) · 856 Bytes
/
admm_serial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
class ADMM(object):
    """ADMM solver for the lasso problem.

    Solves

        min_x  0.5 * ||A^T x - b||_2^2 + lamb * ||x||_1

    by writing it in the generic ADMM form

        min f(x) + g(z)   s.t.  x - z = 0

    with f(x) = 0.5 * ||A^T x - b||^2 and g(z) = lamb * ||z||_1.
    ``nu`` is the (unscaled) Lagrange multiplier of the constraint x - z = 0.

    Parameters
    ----------
    A : ndarray, shape (d, m)
        Note the orientation: the number of optimization variables is
        ``A.shape[0]`` and the model prediction is ``A.T.dot(x)``.
    b : ndarray, shape (m, 1) or (m,)
        Target vector.
    lamb : float
        Weight of the l1 penalty (expected >= 0).
    rho : float
        ADMM penalty parameter; must be > 0.

    Call ``update()`` repeatedly, then read the estimate with ``getparam()``.
    """

    def __init__(self, A, b, lamb, rho):
        if rho <= 0:
            # rho divides lamb and nu in the z-update and must make
            # A A^T + rho*I positive definite.
            raise ValueError("rho must be positive")
        self.A = A
        self.b = b
        self.lamb = lamb
        self.rho = rho
        d = A.shape[0]
        self.x = np.zeros((d, 1))   # primal variable of f
        self.z = np.zeros((d, 1))   # primal variable of g (sparse copy of x)
        self.nu = np.zeros((d, 1))  # dual variable for x - z = 0

    @staticmethod
    def _soft_threshold(v, kappa):
        """Return the elementwise soft-threshold sign(v) * max(|v| - kappa, 0)."""
        return np.sign(v) * np.maximum(np.abs(v) - kappa, 0.0)

    def update(self):
        """Run one ADMM iteration: x-update, z-update, then dual update."""
        d = self.A.shape[0]
        # x-update: argmin_x f(x) + nu^T x + (rho/2)||x - z||^2, i.e. solve
        # (A A^T + rho I) x = A b + rho z - nu.  rho must be added to the
        # diagonal only; solve() avoids forming an explicit inverse.
        lhs = self.A.dot(self.A.T) + self.rho * np.eye(d)
        # Reshape so a 1-D b does not broadcast against the (d, 1) iterates.
        ab = np.reshape(self.A.dot(self.b), (-1, 1))
        rhs = ab + self.rho * self.z - self.nu
        self.x = np.linalg.solve(lhs, rhs)
        # z-update: proximal operator of (lamb/rho)*||.||_1 evaluated at
        # x + nu/rho, which is elementwise soft-thresholding.
        self.z = self._soft_threshold(self.x + self.nu / self.rho,
                                      self.lamb / self.rho)
        # Dual ascent on the consensus constraint x - z = 0.
        self.nu = self.nu + self.rho * (self.x - self.z)

    def getparam(self):
        """Return the current x iterate."""
        return self.x

    def get_diff(self):
        """Print the primal residual x - z and return it."""
        diff = self.x - self.z
        print(diff)
        return diff
def obj_func(A, x, b, lamb):
    """Evaluate the lasso objective 0.5 * ||A^T x - b||_2^2 + lamb * ||x||_1.

    ``A`` is oriented so that the prediction is ``A.T.dot(x)``.  The penalty
    uses ``np.linalg.norm(x, 1)``, which equals the elementwise l1 norm for a
    1-D ``x`` or a single-column ``x``. For a 2-D ``x`` with several columns
    it is the maximum absolute column sum instead.
    """
    residual = A.T.dot(x) - b
    data_fit = 0.5 * np.linalg.norm(residual) ** 2
    penalty = lamb * np.linalg.norm(x, 1)
    return data_fit + penalty