Skip to content

Commit

Permalink
Add option for adaptive step width to SSFM (NVlabs#99)
Browse files Browse the repository at this point in the history
* implement nonlinear-phase rotation

adaptive step width in SSFM as special value for n_ssfm

* fix: graph for adaptive step width

* add tests for SSFM with adaptive step width

* Fixes minor typos in docstrings

---------

Co-authored-by: Jakob Hoydis <[email protected]>
  • Loading branch information
Sebastian Jung and jhoydis authored Feb 20, 2023
1 parent 5e98a6f commit 9e5074a
Show file tree
Hide file tree
Showing 2 changed files with 392 additions and 31 deletions.
116 changes: 87 additions & 29 deletions sionna/channel/optical/fiber.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,10 @@ class SSFM(Layer):
Fiber length :math:`\ell` in :math:`(L_\text{norm})`.
Defaults to 80.0.
n_ssfm : int
n_ssfm : int | "adaptive"
Number of steps :math:`N_\mathrm{SSFM}`.
Set to "adaptive" to use nonlinear-phase rotation to calculate
the step widths adaptively (maxmimum rotation can be set in phase_inc).
Defaults to 1.
n_sp : float
Expand Down Expand Up @@ -162,6 +164,11 @@ class SSFM(Layer):
with_nonlinearity : bool
Apply Kerr nonlinearity. Defaults to `True`.
phase_inc: float
Maximum nonlinear-phase rotation in rad allowed during simulation.
To be used with ``n_ssfm`` = "adaptive".
Defaults to 1e-4.
swap_memory : bool
Use CPU memory for while loop. Defaults to `True`.
Expand Down Expand Up @@ -197,6 +204,7 @@ def __init__(self,
with_dispersion=True,
with_manakov=False,
with_nonlinearity=True,
phase_inc=1e-4,
swap_memory=True,
dtype=tf.complex64,
**kwargs):
Expand All @@ -213,7 +221,21 @@ def __init__(self,
self._gamma = tf.cast(gamma, dtype=self._rdtype)
self._half_window_length = half_window_length
self._length = tf.cast(length, dtype=self._rdtype)
self._n_ssfm = tf.cast(n_ssfm, dtype=tf.int32)
self._phase_inc = tf.cast(phase_inc, dtype=self._rdtype)

if n_ssfm == "adaptive":
self._n_ssfm = tf.cast(-1, dtype=tf.int32) # adaptive == -1
elif isinstance(n_ssfm, int):
self._n_ssfm = tf.cast(n_ssfm, dtype=tf.int32)
# Precalculate uniform step size
tf.assert_greater(self._n_ssfm, 0)
else:
raise ValueError("Unsupported parameter for n_ssfm. \
Either an integer or 'adaptive'.")

# only used for constant step width -> negative value calculated
# with adaptive step widths can be ignored
self._dz = self._length / tf.cast(self._n_ssfm, dtype=self._rdtype)
self._n_sp = tf.cast(n_sp, dtype=self._rdtype)
self._swap_memory = swap_memory
self._t_norm = tf.cast(t_norm, dtype=self._rdtype)
Expand All @@ -236,10 +258,6 @@ def __init__(self,
if self._with_manakov:
self._p_n_ase = self._p_n_ase / 2.0

# Precalculate uniform step size
tf.assert_greater(self._n_ssfm, 0)
self._dz = self._length / tf.cast(self._n_ssfm, dtype=self._rdtype)

self._window = tf.complex(
tf.signal.hamming_window(
window_length=2*self._half_window_length,
Expand Down Expand Up @@ -322,6 +340,35 @@ def _apply_nonlinear_operator(self, q, dz, zeros):

return q


def _calculate_step_width(self, q, remaining_length):
max_power = tf.math.reduce_max(tf.math.pow(tf.math.abs(q),2.0),axis=None)
# ensure that the exact length is reached in the end
dz = tf.math.minimum(self._phase_inc / self._gamma / max_power,remaining_length)
return dz

def _adaptive_step(self,q, precalculations, remaining_length, step_counter):

(window, _, zeros, f) = precalculations

dz = self._calculate_step_width(q,remaining_length)

# Apply window-function
q = self._apply_window(q, window)
q = self._apply_linear_operator(q, dz, zeros, f) # D
q = self._apply_nonlinear_operator(q, dz, zeros) # N
q = self._apply_noise(q, dz)
remaining_length = remaining_length - dz

precalculations = (window, dz, zeros, f)
step_counter = step_counter + 1
return q, precalculations, remaining_length, step_counter

def _cond_adaptive(self, q, precalculations,remaining_length,step_counter):
# pylint: disable=unused-argument
return tf.greater_equal(remaining_length, 1e-3) # avoid numerical issues for 0


def _apply_window(self, q, window):
return q * window

Expand Down Expand Up @@ -376,31 +423,42 @@ def call(self, inputs):

# All-zero vector
zeros = tf.zeros(input_shape, dtype=self._rdtype)

# Spatial step size
dz = tf.cast(self._dz, dtype=self._rdtype)

dz_half = dz/tf.cast(2.0, self._rdtype)

# SSFM step counter
iterator = tf.constant(0, dtype=tf.int32, name="step_counter")

# Symmetric SSFM
# Start with half linear propagation
x = self._apply_linear_operator(x, dz_half, zeros, f)
# Proceed with N_SSFM-1 steps applying nonlinear and linear operator
x, _, _, _ = tf.while_loop(
self._cond,
self._step,
(x, (window, dz, zeros, f), self._n_ssfm-1, iterator),
swap_memory=self._swap_memory,
parallel_iterations=1
)
# Final nonlinear operator
x = self._apply_nonlinear_operator(x, dz, zeros)
# Final noise application
x = self._apply_noise(x, dz)
# End with half linear propagation
x = self._apply_linear_operator(x, dz_half, zeros, f)
if self._n_ssfm == -1: # adaptive step width

x, _, _, _ = tf.while_loop(
self._cond_adaptive,
self._adaptive_step,
(x, (window, tf.cast(0.,self._rdtype), zeros, f), self._length, iterator),
swap_memory=self._swap_memory,
parallel_iterations=1
)

# constant step size
else:
# Spatial step size
dz = tf.cast(self._dz, dtype=self._rdtype)

dz_half = dz/tf.cast(2.0, self._rdtype)

# Symmetric SSFM
# Start with half linear propagation
x = self._apply_linear_operator(x, dz_half, zeros, f)
# Proceed with N_SSFM-1 steps applying nonlinear and linear operator
x, _, _, _ = tf.while_loop(
self._cond,
self._step,
(x, (window, dz, zeros, f), self._n_ssfm-1, iterator),
swap_memory=self._swap_memory,
parallel_iterations=1
)
# Final nonlinear operator
x = self._apply_nonlinear_operator(x, dz, zeros)
# Final noise application
x = self._apply_noise(x, dz)
# End with half linear propagation
x = self._apply_linear_operator(x, dz_half, zeros, f)

return x
Loading

0 comments on commit 9e5074a

Please sign in to comment.