Skip to content

Commit

Permalink
deprecated component get_* methods
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Jun 22, 2020
1 parent 75dfe7d commit fb25a09
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 143 deletions.
109 changes: 21 additions & 88 deletions photontorch/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Component(Module):
it should not be used directly.
To define your own component, subclass Component and overwrite the
__init__ and the get_* methods.
__init__ and the set_* methods.
"""

num_ports = 0
Expand Down Expand Up @@ -57,10 +57,22 @@ def __init__(self, name=None):

def _set_buffers(self):
""" create all buffers for the component """
self.C = Buffer(self.get_C())
self.sources_at = Buffer(self.get_sources_at())
self.detectors_at = Buffer(self.get_detectors_at())
self.actions_at = Buffer(self.get_actions_at())
self.C = Buffer(
torch.zeros((self.num_ports, self.num_ports), device=self.device)
)
self.sources_at = Buffer(
torch.zeros(self.num_ports, device=self.device, dtype=torch.bool)
)
self.detectors_at = Buffer(
torch.zeros(self.num_ports, device=self.device, dtype=torch.bool)
)
self.actions_at = Buffer(
torch.zeros(self.num_ports, device=self.device, dtype=torch.bool)
)
self.set_C(self.C.data)
self.set_sources_at(self.sources_at.data)
self.set_detectors_at(self.detectors_at)
self.set_actions_at(self.actions_at)
self.free_ports_at = Buffer(((self.C.sum(0) > 0) | (self.C.sum(1) > 0)).ne(1))
self.terminated = bool(self.free_ports_at.any().ne(1).item())
self.num_sources = int(self.sources_at.sum())
Expand All @@ -69,11 +81,11 @@ def _set_buffers(self):
self.num_free_ports = int(self.free_ports_at.sum())

if (self.sources_at & self.detectors_at).any():
raise ValueError("Source ports and Detector ports cannot be combined.")
raise ValueError("source ports and detector ports cannot be combined.")
if (self.sources_at & self.actions_at).any():
raise ValueError("Source ports and Active ports cannot be combined.")
raise ValueError("source ports and active ports cannot be combined.")
if (self.detectors_at & self.actions_at).any():
raise ValueError("Detector ports and Active ports cannot be combined.")
raise ValueError("detector ports and active ports cannot be combined.")

if self.actions_at.any() and not (
self.actions_at.all() or isinstance(self, Network)
Expand Down Expand Up @@ -182,86 +194,7 @@ def initialize(self):
"""
self._env = env = current_environment()

self.zero_grad()

self.delays = self.get_delays()
self.S = self.get_S()

return self # return the initialized component, so operations can be chained

def get_S(self):
""" get the scattering matrix of the component
Returns:
Tensor[2, #wavelengths, #ports, #ports]: the scattering
matrix of the component (defined for each
wavelength of the simulation). The first dimension of size two
denotes the stacked real and imaginary part.
"""
S = torch.zeros(
(2, self.env.num_wl, self.num_ports, self.num_ports), device=self.device,
)
self.set_S(S)
return S

def get_C(self):
""" get the connection matrix of the component.
Returns:
Tensor[2, #ports, #ports]: the connection matrix for the
component. The first dimension of size two denotes the stacked
real and imaginary part.
"""
C = torch.zeros((self.num_ports, self.num_ports), device=self.device)
self.set_C(C)
return C

def get_delays(self):
""" Get the delays introduced by the component.
Returns:
Tensor[#ports]: the delay tensor for the component.
The delay tensor signifies the delay each port of the component
introduces.
"""
delays = torch.zeros(self.num_ports, device=self.device)
self.set_delays(delays)
return delays

def get_sources_at(self):
""" Get the locations of the sources in the component.
Returns:
Tensor[#ports]: the boolean tensor for the component
which signifies which ports of the component act as a source.
"""
sources_at = torch.zeros(self.num_ports, device=self.device, dtype=torch.bool)
self.set_sources_at(sources_at)
return sources_at

def get_detectors_at(self):
""" Get the locations of the detectors in the component.
Returns:
Tensor[#ports]: the boolean tensor for the component
which signifies which ports of the component act as a detector.
"""
detectors_at = torch.zeros(self.num_ports, device=self.device, dtype=torch.bool)
self.set_detectors_at(detectors_at)
return detectors_at

def get_actions_at(self):
""" Get the locations of the active nodes in the component.
Returns:
Tensor[#ports]: the boolean tensor for the component
which signifies which ports of the component act actively.
"""
actions_at = torch.zeros(self.num_ports, device=self.device, dtype=torch.bool)
self.set_actions_at(actions_at)
return actions_at
return self

def __repr__(self):
""" String Representation of the component """
Expand Down
3 changes: 1 addition & 2 deletions photontorch/components/mzis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
def set_delays(self, delays):
delays[:] = self.ng * self.length / self.env.c

def get_S(self):
def set_S(self, S):
wls = torch.tensor(self.env.wl, dtype=torch.float64, device=self.device)

# neff depends on the wavelength:
Expand All @@ -98,7 +98,6 @@ def get_S(self):
cos_theta = torch.cos(self.theta).to(torch.get_default_dtype())
sin_theta = torch.sin(self.theta).to(torch.get_default_dtype())
# scattering matrix
S = torch.zeros((2, self.env.num_wl, 4, 4), device=self.device)
S[0, :, 0, 1] = S[0, :, 1, 0] = cos_phi1 * cos_theta
S[1, :, 0, 1] = S[1, :, 1, 0] = sin_phi1 * cos_theta
S[0, :, 0, 2] = S[0, :, 2, 0] = cos_phi1 * sin_theta
Expand Down
101 changes: 55 additions & 46 deletions photontorch/networks/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,19 +311,11 @@ def _conn_sort_key(conn):
self.components = OrderedDict()
for name in self._get_used_component_names():
self.components[name] = all_components[name]
self.components[name].name = name

# set buffers
o = self.order = self._get_port_order()
self.C = Buffer(self.get_C()[o, :][:, o])
self.sources_at = Buffer(self.get_sources_at()[o])
self.detectors_at = Buffer(self.get_detectors_at()[o])
self.actions_at = Buffer(self.get_actions_at()[o])
self.free_ports_at = Buffer(((self.C.sum(0) > 0) | (self.C.sum(1) > 0)).ne(1))
self.terminated = bool(self.free_ports_at.any().ne(1).item())
self.num_sources = int(self.sources_at.sum())
self.num_detectors = int(self.detectors_at.sum())
self.num_actions = int(self.actions_at.sum())
self.num_free_ports = int(self.free_ports_at.sum())
o = self._port_order = self._get_port_order()
super(Network, self)._set_buffers()

# terminate a network
def terminate(self, term=None, name=None):
Expand Down Expand Up @@ -422,29 +414,29 @@ def initialize(self):
reduced matrices without needing the response of the network to an
input signal.
"""
self.zero_grad()

## get current environment
self._env = env = current_environment()

## begin initialization:
self.initialized = False
env = current_environment()

## Initialize components in the network
for comp in self.components.values():
comp.initialize()

## Initialize network
super(Network, self).initialize()
self.S = torch.zeros(
(2, env.num_wl, self.num_ports, self.num_ports), device=self.device
)
self.delays = torch.zeros(self.num_ports, device=self.device)

self.delays = self.delays[self.order]
self.S = self.S[:, :, self.order, :][:, :, :, self.order]
self.set_S(self.S)
self.set_delays(self.delays)

## delays
# delays can be turned off for frequency domain calculations
delays_in_seconds = self.delays * float(not self.env.freqdomain)
delays_in_seconds = self.delays * float(not env.freqdomain)
# resulting delays in terms of the simulation timestep:
if self.env.dt is not None:
delays_in_timesteps = (delays_in_seconds / self.env.dt + 0.5).long()
if env.dt is not None:
delays_in_timesteps = (delays_in_seconds / env.dt + 0.5).long()
else:
delays_in_timesteps = torch.zeros_like(delays_in_seconds).long()

Expand Down Expand Up @@ -520,7 +512,7 @@ def initialize(self):
# We do this in 5 steps:

# 1. Calculation of the helper matrix P = I - Cmlml@Smlml
ones = torch.ones((self.env.num_wl, 1, 1), device=self.device)
ones = torch.ones((env.num_wl, 1, 1), device=self.device)
rP = ones * torch.eye(self.nml, device=self.device)[None, :, :]
rCmlml, _ = torch.broadcast_tensors(rCmlml[None], rSmlml)
rP = rP - (rCmlml).bmm(rSmlml)
Expand Down Expand Up @@ -560,9 +552,8 @@ def initialize(self):
)
self.buffermask[:, delays_in_timesteps[mc], :, range(self.nmc), :] = 1.0

self.initialized = True

# finish initialization:
self._env = env
return self

def _simulation_buffer(self, num_batches):
Expand Down Expand Up @@ -845,30 +836,48 @@ def action(self, t, x_in, x_out):
)
idx += comp.num_ports

def get_delays(self):
""" get all the delays in the network """
return torch.cat([comp.delays for comp in self.components.values()])
def set_S(self, S):
""" get the combined S-matrix of all the components in the network """
idx = 0
for comp in self.components.values():
comp.set_S(S[:, :, idx : idx + comp.num_ports, idx : idx + comp.num_ports])
idx += comp.num_ports
S[:] = S[:, :, self._port_order, :][:, :, :, self._port_order]

def get_detectors_at(self):
""" get the locations of the detectors in the network """
return torch.cat([comp.detectors_at for comp in self.components.values()])
def set_delays(self, delays):
""" set all the delays in the network """
idx = 0
for comp in self.components.values():
comp.set_delays(delays[idx : idx + comp.num_ports])
idx += comp.num_ports
delays[:] = delays[self._port_order]

def get_sources_at(self):
""" get the locations of the sources in the network """
return torch.cat([comp.sources_at for comp in self.components.values()])
def set_detectors_at(self, detectors_at):
""" set the locations of the detectors in the network """
idx = 0
for comp in self.components.values():
comp.set_detectors_at(detectors_at[idx : idx + comp.num_ports])
idx += comp.num_ports
detectors_at[:] = detectors_at[self._port_order]

def get_actions_at(self):
""" get the locations of the functions in the network """
return torch.cat([comp.actions_at for comp in self.components.values()])
def set_sources_at(self, sources_at):
""" set the locations of the sources in the network """
idx = 0
for comp in self.components.values():
comp.set_sources_at(sources_at[idx : idx + comp.num_ports])
idx += comp.num_ports
sources_at[:] = sources_at[self._port_order]

def get_S(self):
""" get the combined S-matrix of all the components in the network """
rS = block_diag(*(comp.S[0] for comp in self.components.values()))
iS = block_diag(*(comp.S[1] for comp in self.components.values()))
return torch.stack([rS, iS])
def set_actions_at(self, actions_at):
""" set the locations of the functions in the network """
idx = 0
for comp in self.components.values():
comp.set_actions_at(actions_at[idx : idx + comp.num_ports])
idx += comp.num_ports
actions_at[:] = actions_at[self._port_order]

def get_C(self):
""" get the combined connection matrix of all the components in the network
def set_C(self, C):
""" set the combined connection matrix of all the components in the network
Returns:
binary tensor with only 1's and 0's.
Expand All @@ -877,7 +886,7 @@ def get_C(self):
To create the connection matrix, the connection strings are parsed.
"""

C = block_diag(*(comp.C for comp in self.components.values()))
C[:] = block_diag(*(comp.C for comp in self.components.values()))

start_idxs = list(
np.cumsum([0] + [comp.num_ports for comp in self.components.values()])[:-1]
Expand Down Expand Up @@ -911,7 +920,7 @@ def parse_connection(conn):
if i is not None:
C[i, j] = C[j, i] = 1.0

return C
C[:] = C[self._port_order, :][:, self._port_order]

def _get_port_order(self):
""" Get the reordering indices for the ports of the network """
Expand Down
9 changes: 5 additions & 4 deletions tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

def test_initialization_with_sources_detectors_at_same_port(tenv):
class WrongTerm(pt.Term):
def get_sources_at(self):
return torch.ones(1, dtype=torch.bool, device=self.device)
def set_sources_at(self, sources_at):
sources_at[0] = True

def get_detectors_at(self):
return torch.ones(1, dtype=torch.bool, device=self.device)
def set_detectors_at(self, detectors_at):
detectors_at[0] = True

with pytest.raises(ValueError):
wt = WrongTerm()
assert wt.sources_at.any()
with tenv:
wt.initialize()

Expand Down
6 changes: 3 additions & 3 deletions tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ def test_cpu(nw):

def test_reinitialize(nw, tenv):
with tenv:
nw.initialized = False # fake the fact that the network is uninitialized
nw._env = None
nw.initialize()
assert nw.initialized
assert nw.env is tenv


def test_initialize_on_unterminated_network(unw, tenv):
with tenv:
unw.initialize()
assert not unw.initialized
assert unw._env is None


def test_initializion_with_too_big_simulation_timestep(nw, tenv):
Expand Down

0 comments on commit fb25a09

Please sign in to comment.