deprecated block_diag
flaport committed Jun 22, 2020
1 parent 8e9e938 commit a417af0
Showing 8 changed files with 483 additions and 591 deletions.
11 changes: 3 additions & 8 deletions photontorch/__init__.py
@@ -102,16 +102,11 @@

## PyTorch extensions

# autograd
from .nn.autograd import block_diag

# neural networks
from .nn import nn
from .nn.nn import Buffer
from .nn.nn import BoundedParameter
from .nn.nn import Module

# custom functional additions
from .nn.functional import BERLoss
from .nn.functional import MSELoss
from .nn.functional import BitStreamGenerator
from .nn.nn import BERLoss
from .nn.nn import MSELoss
from .nn.nn import BitStreamGenerator
7 changes: 4 additions & 3 deletions photontorch/networks/network.py
@@ -38,7 +38,6 @@ def __init__(self):
from ..nn.nn import Buffer
from ..components.component import Component
from ..components.terms import Term
from ..nn.autograd import block_diag
from ..environment import current_environment


@@ -885,8 +884,10 @@ def set_C(self, C):
Note:
To create the connection matrix, the connection strings are parsed.
"""

C[:] = block_diag(*(comp.C for comp in self.components.values()))
idx = 0
for comp in self.components.values():
comp.set_C(C[idx : idx + comp.num_ports, idx : idx + comp.num_ports])
idx += comp.num_ports

start_idxs = list(
np.cumsum([0] + [comp.num_ports for comp in self.components.values()])[:-1]
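
The new set_C body above replaces the single block_diag call with an in-place write of each component's connection matrix into a diagonal sub-view of the preallocated C. A minimal standalone sketch of that idea (not photontorch code; it assumes square 2-D blocks and uses torch.block_diag, available in recent PyTorch versions, only to check the result):

import torch

# stand-ins for the per-component connection matrices (comp.C)
blocks = [torch.rand(2, 2), torch.rand(3, 3), torch.rand(1, 1)]

num_ports = sum(b.shape[0] for b in blocks)
C = torch.zeros(num_ports, num_ports)

idx = 0
for block in blocks:
    n = block.shape[0]
    C[idx : idx + n, idx : idx + n] = block  # write into the diagonal sub-block in place
    idx += n

# the assembled C equals the block-diagonal matrix built up front
assert torch.equal(C, torch.block_diag(*blocks))
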
1 change: 0 additions & 1 deletion photontorch/networks/rings.py
@@ -23,7 +23,6 @@
from ..components.waveguides import Waveguide
from ..components.directionalcouplers import DirectionalCoupler
from ..nn.nn import Buffer, Parameter
from ..nn import block_diag


###################
42 changes: 7 additions & 35 deletions photontorch/nn/__init__.py
@@ -1,37 +1,9 @@
""" Torch Extensions for PhotonTorch
Since PhotonTorch is a photonic simulation framework in the first place,
we require some extra functionalities that PyTorch does not offer out of
the box.
Below you can find a short summary:
* ``block_diag``: a differentiable implementation of a block diagonal matrix
performed over a batch of matrices.
* ``BoundedParameter``: A bounded parameter is a special kind of
``torch.nn.Parameter`` that is bounded between a certain range.
* ``Buffer``: A special kind of tensor that automatically will
be added to the ``._buffers`` attribute of the Module. Buffers are typically
used as parameters of the model that do not require gradients.
* ``Module``: Extends ``torch.nn.Module``, with some extra
features, such as automatically registering a ``Buffer`` in its
``._buffers`` attribute, modified ``.cuda()`` calls and some extra
functionalities.
* ``BitStreamGenerator``: A simple class that generates random bitstreams.
* ``BERLoss``: A Module that calculates the bit error rate between two bitstreams.
* ``MSELoss``: A Module that calculates the mean squared error between two bitstreams.
"""


## custom torch differentiable functions
from .autograd import block_diag


## custom neural network functions [not imported]
# nn.nn
""" neural network (nn) extensions """

# custom functional additions
from .functional import BERLoss
from .functional import MSELoss
from .functional import BitStreamGenerator
from .nn import Module
from .nn import Buffer
from .nn import BERLoss
from .nn import MSELoss
from .nn import BoundedParameter
from .nn import BitStreamGenerator
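
For downstream code, the visible effect of this reorganization is that the loss and bitstream helpers now come from photontorch.nn.functional instead of photontorch.nn.nn, while block_diag disappears from the public API. A hypothetical before/after, assuming only the import paths shown in this diff:

# before this commit
# from photontorch.nn.nn import BERLoss, MSELoss, BitStreamGenerator
# from photontorch.nn.autograd import block_diag

# after this commit
from photontorch.nn.functional import BERLoss, MSELoss, BitStreamGenerator
from photontorch.nn import Module, Buffer, BoundedParameter
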
55 changes: 0 additions & 55 deletions photontorch/nn/autograd.py

This file was deleted.
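
The deleted autograd.py held the custom differentiable block_diag described in the old docstring above (a block-diagonal matrix built over a batch of matrices). A hedged sketch of the general technique, not the deleted implementation: the same result can be built from ordinary tensor ops, so plain autograd tracks the gradients and no custom autograd.Function is needed.

import torch

def batched_block_diag(*matrices):
    """Stack (batch, ni, mi) matrices into one (batch, sum(ni), sum(mi)) block-diagonal tensor."""
    batch = matrices[0].shape[0]
    rows = sum(m.shape[1] for m in matrices)
    cols = sum(m.shape[2] for m in matrices)
    out = matrices[0].new_zeros(batch, rows, cols)
    r = c = 0
    for m in matrices:
        out[:, r : r + m.shape[1], c : c + m.shape[2]] = m  # differentiable slice assignment
        r += m.shape[1]
        c += m.shape[2]
    return out

a = torch.rand(4, 2, 2, requires_grad=True)
b = torch.rand(4, 3, 3, requires_grad=True)
batched_block_diag(a, b).sum().backward()  # gradients flow back to a and b
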

