Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add surrogate-informed priors #73

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,182 changes: 2,182 additions & 0 deletions docs/examples/imr_injection_w_surrogate.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: ringdown
channels:
- conda-forge
- defaults
dependencies:
- python=3.12
- numpy=1.26.4
Expand Down
63 changes: 63 additions & 0 deletions ringdown/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,69 @@ def set_modes(self, modes: int | list[tuple[int, int, int, int, int]]):
"""
self.modes = indexing.ModeIndexList(modes)

def set_surrogate(self, surrogate):
    """Register a QNM surrogate bound to the currently selected modes.

    The stored ``self.surrogate`` can later be used to inform QNM
    amplitude/phase priors or to perform parameter estimation.

    TODO: load the surrogate from the surrogate repository instead of
    requiring the caller to pass it in.

    Arguments
    ---------
    surrogate : callable
        NR surrogate of QNM amplitudes. It is invoked as
        ``surrogate(x, QNMs=..., M=..., dist_mpc=..., inclination=...,
        phi_ref=...)``; its return value is passed through unchanged.

    Raises
    ------
    ValueError
        If no modes have been set yet (``self.modes.value`` is ``0`` or
        an empty sequence).
    TypeError
        If ``self.modes.value`` is a non-zero integer (a bare number of
        modes), because the surrogate needs explicit mode indices.
    """
    mode_indices = self.modes.value

    if isinstance(mode_indices, int):
        if mode_indices == 0:
            raise ValueError("self.modes needs to be set before setting the surrogate.")
        # Iterating an integer below would otherwise fail with an opaque
        # "'int' object is not iterable"; keep the exception type but say why.
        raise TypeError(
            "self.modes must be a list of explicit mode indices to set the "
            f"surrogate, not a number of modes ({mode_indices})."
        )
    if len(mode_indices) == 0:
        raise ValueError("self.modes needs to be set before setting the surrogate.")

    # The surrogate takes 4-tuples: drop the second entry of each 5-tuple
    # mode index (presumably the spin weight -- TODO confirm).
    qnms = [(mode[0], mode[2], mode[3], mode[4]) for mode in mode_indices]

    def bound_surrogate(x, M=None, dist_mpc=None, inclination=None, phi_ref=None):
        """Evaluate the surrogate for the QNMs fixed at registration time."""
        return surrogate(
            x, QNMs=qnms, M=M, dist_mpc=dist_mpc,
            inclination=inclination, phi_ref=phi_ref
        )

    self.surrogate = bound_surrogate

def set_priors_from_surrogate(self, progenitor_parameters):
    """Compute QNM amplitude/phase priors from a progenitor posterior.

    The surrogate registered in ``self.surrogate`` is evaluated on every
    sample; the mean and standard deviation of the resulting amplitudes
    and phases are stored in ``self.model_settings``.

    Progenitor parameters should match the expected structure of the
    QNM surrogate, e.g., (q, chi1z, chi2z) or
    (q, chi1x, chi1y, chi1z, chi2x, chi2y, chi2z), followed by the
    extrinsic parameters (M, dist_mpc, inclination, phi_ref).

    Phase statistics are circular: deviations are measured about the
    circular mean and wrapped into (-pi, pi], so a phase distribution
    straddling the +/-pi branch cut is not artificially broadened. When
    the phases do not wrap, the result equals the ordinary mean and
    standard deviation. The stored phase mean may therefore lie slightly
    outside (-pi, pi].

    Arguments
    ---------
    progenitor_parameters : ndarray
        2d array of progenitor parameters with the parameters in axis=1
        (7 or 11 columns) and one sample per row.

    Raises
    ------
    ValueError
        If the input is not 2d, has no samples, or does not have 3 or 7
        intrinsic parameters ahead of the 4 extrinsic ones.

    Side effects
    ------------
    Sets ``self.model_settings['marginalized']`` to ``False`` and
    ``self.model_settings['surrogate_means_and_stds']`` to an array of
    shape ``(n_qnm, 4)`` ordered as mean A, std A, mean phase, std phase.
    """
    samples = np.asarray(progenitor_parameters)
    if samples.ndim != 2:
        raise ValueError(
            "progenitor_parameters must be a 2d array, got shape "
            f"{samples.shape}."
        )

    # The last four columns are (M, dist_mpc, inclination, phi_ref).
    n_extrinsic = 4
    idx = samples.shape[1] - n_extrinsic
    if idx not in (3, 7):
        raise ValueError(
            f"progenitor_parameters.shape {samples.shape} is invalid: axis=1 "
            "must hold 3 or 7 intrinsic parameters followed by the 4 "
            "extrinsic ones (7 or 11 columns in total)."
        )
    if samples.shape[0] == 0:
        raise ValueError("progenitor_parameters contains no samples.")

    # One row of complex QNM amplitudes per posterior sample; the column
    # order is the iteration order of the mapping the surrogate returns.
    qnm_amplitudes = np.array([
        list(self.surrogate(
            sample[:idx],
            M=sample[idx],
            dist_mpc=sample[idx + 1],
            inclination=sample[idx + 2],
            phi_ref=sample[idx + 3]
        ).values())
        for sample in samples
    ])

    stats = []
    for i in range(qnm_amplitudes.shape[1]):
        mode_amplitudes = qnm_amplitudes[:, i]
        magnitude = np.abs(mode_amplitudes)
        phase = np.angle(mode_amplitudes)

        # Centre on the circular mean, then wrap the residuals so samples
        # on either side of +/-pi count as neighbours.
        centre = np.angle(np.mean(np.exp(1j * phase)))
        residual = np.angle(np.exp(1j * (phase - centre)))

        stats.append([
            np.mean(magnitude),
            np.std(magnitude),
            centre + np.mean(residual),
            np.std(residual),
        ])
    means_and_stds = jax.numpy.array(stats)

    # Surrogate-informed priors are placed on the amplitudes directly, so
    # the analytically marginalized model cannot be used.
    self.model_settings['marginalized'] = False
    self.model_settings['surrogate_means_and_stds'] = means_and_stds

def set_target(self, t0: float | dict | None = None,
ra: float | None = None,
dec: float | None = None, psi: float | None = None,
Expand Down
122 changes: 107 additions & 15 deletions ringdown/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging


def rd_design_matrix(ts, f, gamma, Fp, Fc, Ascales, aligned=False,
def rd_design_matrix(ts, f, gamma, Fp, Fc, Ascales, t_ref=0.0, aligned=False,
YpYc=None, single_polarization=False):
"""Construct the design matrix for a generic ringdown model.

Expand Down Expand Up @@ -136,6 +136,8 @@ def rd_design_matrix(ts, f, gamma, Fp, Fc, Ascales, aligned=False,
The cross polarization coefficients; shape ``(nifo,)``.
Ascales : array_like
The amplitude scales of the damped sinusoids; shape ``(nmode,)``.
t_ref : array_like
The reference time difference between the prior and inferred amplitudes.

Returns
-------
Expand All @@ -145,6 +147,7 @@ def rd_design_matrix(ts, f, gamma, Fp, Fc, Ascales, aligned=False,
# times should be originally shaped (nifo, nt)
# take it to (nifo, nt, 1) where the last dimension is the mode
ts = jnp.atleast_2d(ts)[:, :, jnp.newaxis]
brt_ref = t_ref * jnp.ones_like(ts)

# get number of detectors, times, and modes
nifo = ts.shape[0]
Expand All @@ -157,9 +160,9 @@ def rd_design_matrix(ts, f, gamma, Fp, Fc, Ascales, aligned=False,
Ascales = jnp.reshape(Ascales, (1, 1, nmode))

# ct and st will have shape (1, nt, nmode)
decay = jnp.exp(-gamma*ts)
ct = Ascales * decay * jnp.cos(2*np.pi*f*ts)
st = Ascales * decay * jnp.sin(2*np.pi*f*ts)
decay = jnp.exp(-gamma*(ts - t_ref))
ct = Ascales * decay * jnp.cos(2*np.pi*f*(ts - t_ref))
st = Ascales * decay * jnp.sin(2*np.pi*f*(ts - t_ref))

if single_polarization:
dm = jnp.concatenate((Fp*ct, Fp*st), axis=2)
Expand Down Expand Up @@ -239,7 +242,6 @@ def get_quad_derived_quantities(nmodes, design_matrices, quads, a_scale, YpYc,
if nquads == 2:
ax_unit = quads[:nmodes]
ay_unit = quads[nmodes:]

a_norm = jnp.sqrt(jnp.square(ax_unit) + jnp.square(ay_unit))
a = numpyro.deterministic('a', a_scale * a_norm)
numpyro.deterministic('phi', jnp.arctan2(ay_unit, ax_unit))
Expand Down Expand Up @@ -302,6 +304,8 @@ def get_quad_derived_quantities(nmodes, design_matrices, quads, a_scale, YpYc,
def make_model(modes: int | list[(int, int, int, int)],
a_scale_max: float,
marginalized: bool = True,
surrogate_means_and_stds: float | None = None,
sample_t_ref: bool = False,
m_min: float | None = None,
m_max: float | None = None,
chi_min: float = 0.0,
Expand Down Expand Up @@ -341,6 +345,19 @@ def make_model(modes: int | list[(int, int, int, int)],
Whether or not to marginalize over the quadrature amplitudes
analytically.

surrogate_means_and_stds : array
Array of amplitude and phase means and standard deviations
extracted from a surrogate run on an IMR posterior. Array should be 2d,
with axis=0 being the different QNMs and axis=1 being of size 4
and ordered by mean A, std A, mean phase, std phase.
(default: None, i.e., use normal distributions on A_x and A_y)

sample_t_ref : bool
Whether or not to sample t_ref. This should be used in conjunction
with the surrogate means and standard deviations, but is likely
not necessary so long as the standard deviations from the surrogate
are large enough to allow for flexible sampling.

m_min : float
The minimum mass of the black hole in solar masses.

Expand Down Expand Up @@ -419,7 +436,7 @@ def make_model(modes: int | list[(int, int, int, int)],
A model function that can be used with `numpyro` to sample from the
posterior distribution of the ringdown parameters.
"""

n_modes = modes if isinstance(modes, int) else len(modes)

# check arguments for free damped sinusoid fits
Expand Down Expand Up @@ -490,11 +507,12 @@ def make_model(modes: int | list[(int, int, int, int)],
swsh = construct_sYlm(-2, mode_array[:, 2], mode_array[:, 3])
else:
swsh = None

def model(times, strains, ls, fps, fcs,
predictive: bool = predictive,
store_h_det: bool = store_h_det,
store_h_det_mode: bool = store_h_det_mode):
store_h_det_mode: bool = store_h_det_mode,
a_scale_max=a_scale_max):
"""The ringdown model.

Arguments
Expand All @@ -514,6 +532,7 @@ def model(times, strains, ls, fps, fcs,
fcs : array_like
The "cross" polarization coefficients for each IFO; length `n_det`.
"""

times, strains, ls, fps, fcs = map(
jnp.array, (times, strains, ls, fps, fcs))

Expand Down Expand Up @@ -636,7 +655,6 @@ def model(times, strains, ls, fps, fcs,
# https://arxiv.org/abs/2005.14199
# for ease of reference: we use the same variable names
# and matrices are capitalized

a_scale = numpyro.sample('a_scale', dist.Uniform(0, a_scale_max),
sample_shape=(n_modes,))
# get design matrices which will have shape
Expand Down Expand Up @@ -800,7 +818,7 @@ def model(times, strains, ls, fps, fcs,
get_quad_derived_quantities(n_modes, dms, quads,
a_scale, YpYc, store_h_det,
store_h_det_mode)
else:
elif surrogate_means_and_stds is None:
a_scales = a_scale_max*jnp.ones(n_modes)
dms = rd_design_matrix(times, f, g, fps, fcs, a_scales,
aligned=swsh, YpYc=YpYc,
Expand All @@ -823,11 +841,18 @@ def model(times, strains, ls, fps, fcs,
quads = jnp.concatenate(
(apx_unit, apy_unit, acx_unit, acy_unit))

a, h_det = get_quad_derived_quantities(n_modes, dms,
quads, a_scale_max, YpYc,
store_h_det,
store_h_det_mode,
compute_h_det=(not prior))
if prior:
get_quad_derived_quantities(n_modes, dms,
quads, a_scale_max, YpYc,
store_h_det,
store_h_det_mode,
compute_h_det=(not prior))
else:
a, h_det = get_quad_derived_quantities(n_modes, dms,
quads, a_scale_max, YpYc,
store_h_det,
store_h_det_mode,
compute_h_det=(not prior))

if flat_amplitude_prior:
# We need a Jacobian that is A^-3 for the generic model
Expand All @@ -846,7 +871,74 @@ def model(times, strains, ls, fps, fcs,
for i, strain in enumerate(strains):
numpyro.sample(f'logl_{i}', dist.MultivariateNormal(
h_det[i, :], scale_tril=ls[i, :, :]), obs=strain)
else:
a_scale_max = 1
a_scales = a_scale_max * jnp.ones(n_modes)

if sample_t_ref:
if n_modes > 1:
t_ref = numpyro.sample(
't_ref', dist.Normal(
0,
0.005
)
)
else:
t_ref = 0.0
else:
t_ref = 0.0

dms = rd_design_matrix(times, f, g, fps, fcs, a_scales,
t_ref=t_ref, aligned=swsh, YpYc=YpYc,
single_polarization=single_polarization)

if swsh or single_polarization:
a = numpyro.sample(
'a_temp', dist.Normal(
0,
1
), sample_shape=(n_modes,)
)
phi = numpyro.sample(
'phi_temp', dist.Normal(
0,
1
), sample_shape=(n_modes,)
)
a = a * surrogate_means_and_stds[:,1] + surrogate_means_and_stds[:,0]
phi = phi * surrogate_means_and_stds[:,3] + surrogate_means_and_stds[:,2]
psi = numpyro.sample(
'psi', dist.Uniform(
0, 2*jnp.pi
)
)
quads = jnp.concatenate(
(
a * jnp.cos(phi) * jnp.cos(2*psi) - a * jnp.sin(phi) * jnp.sin(2*psi),
a * jnp.sin(phi) * jnp.cos(2*psi) + a * jnp.cos(phi) * jnp.sin(2*psi)
)
)
else:
raise ValueError("Not implemented!")

if prior:
get_quad_derived_quantities(n_modes, dms,
quads, a_scale_max, YpYc,
store_h_det,
store_h_det_mode,
compute_h_det=(not prior))
else:
a, h_det = get_quad_derived_quantities(n_modes, dms,
quads, a_scale_max, YpYc,
store_h_det,
store_h_det_mode,
compute_h_det=(not prior))

for i, strain in enumerate(strains):
numpyro.sample(f'logl_{i}', dist.MultivariateNormal(
h_det[i, :], scale_tril=ls[i, :, :]), obs=strain)


return model


Expand Down
Loading