Skip to content

Commit

Permalink
dss
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Apr 5, 2022
1 parent 71861af commit d415fde
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions s4/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def reciprocal(x):
e = np.exp(x2)
return e * reciprocal(np.sum(e))

def test_softmax():
x = np.array([1,2,3,4])
a = np.exp(x) / np.exp(x).sum()
b = complex_softmax(x)
assert np.isclose(a, b).all()

def dss_kernel(W, Lambda, L, step):
P = (step * Lambda)[:, None] * np.arange(L)
S = jax.vmap(complex_softmax)(P)
Expand All @@ -27,9 +33,8 @@ class DSSLayer(nn.Module):

def setup(self):
# Learned Parameters
self.W = self.param(
"W", s4.lecun_normal(dtype=np.complex64), (1, self.N)
)
self.W = self.param("W", s4.lecun_normal(), (1, self.N, 2))
self.W = self.W[..., 0] + 1j * self.W[..., 1]
self.D = self.param("D", nn.initializers.ones, (1,))
self.step = np.exp(self.param("log_step", s4.log_step_initializer(), (1,)))
self.K = dss_kernel(self.W, self.Lambda, self.l_max, self.step)
Expand Down

0 comments on commit d415fde

Please sign in to comment.