5 changes: 2 additions & 3 deletions environment.yml
@@ -1,13 +1,12 @@
name: ringdown
channels:
  - conda-forge
-  - defaults
dependencies:
  - python=3.12
  - numpy=1.26.4
  - lalsuite=7.23
-  - h5py=3.11
-  - arviz=0.19
+  - h5py=3.12
+  - arviz=0.20
  - pandas=2.2
  - qnm=0.4.3
  - seaborn=0.13
3 changes: 1 addition & 2 deletions ringdown/fit.py
@@ -870,8 +870,7 @@ def run(self,
        if 'a_scale_max' in ms:
            ms['a_scale_max'] = ms['a_scale_max'] / self.strain_scale
        logging.info('making model')
-        model = make_model(self.modes.value, prior=prior, predictive=False,
-                           store_h_det=False, store_h_det_mode=False, **ms)
+        model = make_model(self.modes.value, prior=prior, **ms)
        if return_model:
            return model

261 changes: 157 additions & 104 deletions ringdown/model.py
@@ -3,6 +3,7 @@
import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp
+from jax import lax

import numpyro
import numpyro.distributions as dist
@@ -185,7 +186,7 @@ def chi_factors(chi, coeffs):
log_sqrt_1m_chi2_4 = log_sqrt_1m_chi2_2*log_sqrt_1m_chi2_2
log_sqrt_1m_chi2_5 = log_sqrt_1m_chi2_3*log_sqrt_1m_chi2_2
log_sqrt_1m_chi2_6 = log_sqrt_1m_chi2_3*log_sqrt_1m_chi2_3

v = jnp.stack([
1.,
log_1m_chi,
@@ -199,7 +200,7 @@ def chi_factors(chi, coeffs):
log_sqrt_1m_chi2_5,
log_sqrt_1m_chi2_6
])

return jnp.dot(coeffs, v)


@@ -289,6 +290,146 @@ def get_quad_derived_quantities(nmodes, design_matrices, quads, a_scale, YpYc,
return a, h_det


def make_likelihood_update(dms, ls, strains):
    """Construct a marginalized-likelihood update over detectors, to be
    used with 'jax.lax.scan' or 'numpyro.scan'.

    Arguments
    ---------
    dms : array_like
        The design matrices, one per detector; shape
        (ndet, ntime, nquads*nmode).
    ls : array_like
        The Cholesky factors of the noise covariances, one per detector;
        shape (ndet, ntime, ntime).
    strains : array_like
        The strain data, one per detector; shape (ndet, ntime).

    Returns
    -------
    likelihood_update : callable
        A scan-compatible update function. Its carry is the tuple
        (mu, Lambda_inv, Lambda_inv_chol), holding the prior mean, shape
        (nquads*nmode,), the prior precision, shape
        (nquads*nmode, nquads*nmode), and the Cholesky factor of the
        prior precision, same shape. Given the carry and a detector
        index i, it returns the updated carry and the marginal log
        likelihood contributed by detector i.
    """
def likelihood_update(mu_Lambda_inv_Lambda_inv_chol, i):
# unpack (carry) variables and (current state) parameters
mu, Lambda_inv, Lambda_inv_chol = mu_Lambda_inv_Lambda_inv_chol
M = dms[i]
L = ls[i]
y = strains[i]

# M acts as a coordinate transformation matrix, taking us
        # from the space of quadratures to the space of the data,
# while M^T takes us from data space to quadrature space
# (M is ntime x nquads*nmode)

# L whitens the noise in the detector, taking it from
# N(0, C) to N(0, I) (L is ntime x ntime)

# we can use M and L to compute the precision (A_inv) of
# the marginal posterior on the quadratures (conditioned on
# the current data and nonlinear parameters), which is just
# the sum of the prior precision (Lambda_inv) and the
# likelihood precision (M^T C^-1 M):
# A_inv = Lambda_inv + M^T C^-1 M
# so that A and A_inv are (nquads*nmode, nquads*nmode)
A_inv = Lambda_inv + \
jnp.dot(M.T, jsp.linalg.cho_solve((L, True), M))
A_inv_chol = jsp.linalg.cholesky(A_inv, lower=True)

# we can also compute the marginal-posterior mean (a),
# which is the precision-weighted sum of the prior mean
# (mu) and the likelihood mean (M^T C^-1 y):
# a = A_inv (Lambda_inv mu + M^T C^-1 y)
# so that a is (nquads*nmode,)
a = jsp.linalg.cho_solve(
(A_inv_chol, True), jnp.dot(Lambda_inv, mu) +
jnp.dot(M.T, jsp.linalg.cho_solve((L, True), y)))

        # the mean (b) of the marginal likelihood p(y|b, B),
        # i.e., the likelihood obtained after integrating out
        # the quadratures, is simply the strain predicted by the
        # prior mean quadratures, i.e., mu after a coordinate
        # transformation:
        # b = M mu
# so that b is (ntime,)
b = jnp.dot(M, mu)

# the (co)variance of the marginal likelihood (B) is the
# sum of the variance from the noise (C) and the variance
# from the quadrature prior (Lambda):
# B = C + M Lambda M^T
# this is (ntime, ntime), which is large; but, to compute
# the marginal likelihood, we need the inverse covariance
# B^-1, so we can use the Woodbury identity to write:
# B^-1 = C^-1 - C^-1 M (Lambda^-1 + M^T C^-1 M)^-1 M^T C^-1
# = C^-1 - C^-1 M A M^T C^-1
# where A = A_inv^-1 per the above; this way we avoid
# inverting the large matrix B directly and take advantage
# of the precomputed Cholesky factor L to get C^-1

# with the residual r = y - b, the marginal log-likelihood
# becomes
# logl = -0.5 r^T B^-1 r - 0.5 log |2pi B|
# where |2pi B| is the determinant of 2pi*B and we can
# ignore the 2pi factor since it introduces a term like
# - 0.5*ntime*log(2pi), which is constant
r = y - b
Cinv_r = jsp.linalg.cho_solve((L, True), r)

M_A_Mt_Cinv_r = jnp.dot(M, jsp.linalg.cho_solve(
(A_inv_chol, True), jnp.dot(M.T, Cinv_r)))

Cinv_M_A_Mt_Cinv_r = \
jsp.linalg.cho_solve((L, True), M_A_Mt_Cinv_r)

# now all we have left to compute is the log determinant
# term, 0.5*log|B|; from the Gaussian refactorization, we
# have that
# |Lambda| |C| = |A| |B|
# and therefore
# log|B| = log|C| + log|Lambda| - log|A|
# furthermore, since |C| = |L|^2, we can write
# 0.5 log|C| = log|L|
# and |L| is the product of the diagonal entries of L;
# writing similarly for |A| and |Lambda|, we thus have
# that log_sqrt_det_B = 0.5 log|B| is
        # (note that log|A| = -log|A_inv|)
log_sqrt_det_B = \
jnp.sum(jnp.log(jnp.diag(L))) - \
jnp.sum(jnp.log(jnp.diag(Lambda_inv_chol))) + \
jnp.sum(jnp.log(jnp.diag(A_inv_chol)))

# putting it all together we can get the contribution
# to the log likelihood from this detector
logl = -0.5*jnp.dot(r, Cinv_r - Cinv_M_A_Mt_Cinv_r) \
- log_sqrt_det_B

# numpyro.factor(f'logl_{i}', logl)

# update the prior mean and precision for the next detector
mu = a
Lambda_inv = A_inv
Lambda_inv_chol = A_inv_chol

return (mu, Lambda_inv, Lambda_inv_chol), logl
return likelihood_update
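
The algebra above can be sanity-checked against a direct evaluation of the marginal likelihood. The following is a minimal NumPy sketch (editorial illustration, not part of the PR; all shapes and values are synthetic) comparing the Woodbury/Cholesky route used in `likelihood_update` with the dense computation y ~ N(M mu, C + M Lambda M^T), up to the constant -0.5*ntime*log(2*pi) that the model drops:

```python
import numpy as np

rng = np.random.default_rng(0)
ntime, nq = 6, 3

M = rng.standard_normal((ntime, nq))     # design matrix
Lam = 0.5 * np.eye(nq)                   # prior covariance Lambda
Lambda_inv = np.linalg.inv(Lam)          # prior precision
mu = rng.standard_normal(nq)             # prior mean
C = 0.1 * np.eye(ntime)                  # noise covariance
L = np.linalg.cholesky(C)                # C = L L^T
y = rng.standard_normal(ntime)           # strain data

# dense route: y ~ N(M mu, B) with B = C + M Lambda M^T
B = C + M @ Lam @ M.T
r = y - M @ mu
logl_dense = -0.5 * r @ np.linalg.solve(B, r) \
    - 0.5 * np.linalg.slogdet(B)[1]

# Woodbury route, as in likelihood_update
A_inv = Lambda_inv + M.T @ np.linalg.solve(C, M)
Cinv_r = np.linalg.solve(C, r)
quad = r @ Cinv_r - Cinv_r @ M @ np.linalg.solve(A_inv, M.T @ Cinv_r)
log_sqrt_det_B = np.sum(np.log(np.diag(L))) \
    - 0.5 * np.linalg.slogdet(Lambda_inv)[1] \
    + 0.5 * np.linalg.slogdet(A_inv)[1]
logl_woodbury = -0.5 * quad - log_sqrt_det_B

assert np.isclose(logl_dense, logl_woodbury)
```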


def make_model(modes: int | list[(int, int, int, int)],
a_scale_max: float,
marginalized: bool = True,
@@ -311,9 +452,9 @@ def make_model(modes: int | list[(int, int, int, int)],
mode_ordering: None | str = None,
single_polarization: bool = False,
prior: bool = False,
-               predictive: bool = True,
-               store_h_det: bool = True,
-               store_h_det_mode: bool = True):
+               predictive: bool = False,
+               store_h_det: bool = False,
+               store_h_det_mode: bool = False):
"""
Arguments
---------
@@ -652,107 +793,19 @@ def model(times, strains, ls, fps, fcs,
# iterating over all detectors, we have turned the prior into
# the posterior

-            for i in range(n_det):
-                # select the design matrix (M), the Cholesky factor (L),
-                # and the strain (y) for the current detector
-                # (ndet, ntime, nquads*nmode) => (i, ntime, nquads*nmode)
-                M = dms[i, :, :]
-                L = ls[i, :, :]
-                y = strains[i, :]
-
-                # M acts as a coordinate transformation matrix, taking us
-                # from the space of quadratures to the space of the data ,
-                # while M^T takes us from data space to quadrature space
-                # (M is ntime x nquads*nmode)
-
-                # L whitens the noise in the detector, taking it from
-                # N(0, C) to N(0, I) (L is ntime x ntime)
-
-                # we can use M and L to compute the precision (A_inv) of
-                # the marginal posterior on the quadratures (conditioned on
-                # the current data and nonlinear parameters), which is just
-                # the sum of the prior precision (Lambda_inv) and the
-                # likelihood precision (M^T C^-1 M):
-                # A_inv = Lambda_inv + M^T C^-1 M
-                # so that A and A_inv are (nquads*nmode, nquads*nmode)
-                A_inv = Lambda_inv + \
-                    jnp.dot(M.T, jsp.linalg.cho_solve((L, True), M))
-                A_inv_chol = jsp.linalg.cholesky(A_inv, lower=True)
-
-                # we can also compute the marginal-posterior mean (a),
-                # which is the precision-weighted sum of the prior mean
-                # (mu) and the likelihood mean (M^T C^-1 y):
-                # a = A_inv (Lambda_inv mu + M^T C^-1 y)
-                # so that a is (nquads*nmode,)
-                a = jsp.linalg.cho_solve(
-                    (A_inv_chol, True), jnp.dot(Lambda_inv, mu) +
-                    jnp.dot(M.T, jsp.linalg.cho_solve((L, True), y)))
-
-                # the mean (b) of the marginal likelihood p(y|b, B),
-                # i.e., the likelihood obtained after integrating out
-                # the quadratures, is simply the value of the strain y
-                # corresponding to the mean quadratures, i.e., mu after
-                # a coordinate transformation:
-                # b = M mu
-                # so that b is (ntime,)
-                b = jnp.dot(M, mu)
-
-                # the (co)variance of the marginal likelihood (B) is the
-                # sum of the variance from the noise (C) and the variance
-                # from the quadrature prior (Lambda):
-                # B = C + M Lambda M^T
-                # this is (ntime, ntime), which is large; but, to compute
-                # the marginal likelihood, we need the inverse covariance
-                # B^-1, so we can use the Woodbury identity to write:
-                # B^-1 = C^-1 - C^-1 M (Lambda^-1 + M^T C^-1 M)^-1 M^T C^-1
-                #      = C^-1 - C^-1 M A M^T C^-1
-                # where A = A_inv^-1 per the above; this way we avoid
-                # inverting the large matrix B directly and take advantage
-                # of the precomputed Cholesky factor L to get C^-1
-
-                # with the residual r = y - b, the marginal log-likelihood
-                # becomes
-                # logl = -0.5 r^T B^-1 r - 0.5 log |2pi B|
-                # where |2pi B| is the determinant of 2pi*B and we can
-                # ignore the 2pi factor since it introduces a term like
-                # - 0.5*ntime*log(2pi), which is constant
-                r = y - b
-                Cinv_r = jsp.linalg.cho_solve((L, True), r)
-
-                M_A_Mt_Cinv_r = jnp.dot(M, jsp.linalg.cho_solve(
-                    (A_inv_chol, True), jnp.dot(M.T, Cinv_r)))
-
-                Cinv_M_A_Mt_Cinv_r = \
-                    jsp.linalg.cho_solve((L, True), M_A_Mt_Cinv_r)
-
-                # now all we have left to compute is the log determinant
-                # term, 0.5*log|B|; from the Gaussian refactorization, we
-                # have that
-                # |Lambda| |C| = |A| |B|
-                # and therefore
-                # log|B| = log|C| + log|Lambda| - log|A|
-                # furthermore, since |C| = |L|^2, we can write
-                # 0.5 log|C| = log|L|
-                # and |L| is the product of the diagonal entries of L;
-                # writing similarly for |A| and |Lambda|, we thus have
-                # that log_sqrt_det_B = 0.5 log|B| is
-                # (note that |A| = -|A_inv|)
-                log_sqrt_det_B = \
-                    jnp.sum(jnp.log(jnp.diag(L))) - \
-                    jnp.sum(jnp.log(jnp.diag(Lambda_inv_chol))) + \
-                    jnp.sum(jnp.log(jnp.diag(A_inv_chol)))
-
-                # putting it all together we can get the contribution
-                # to the log likelihood from this detector
-                logl = -0.5*jnp.dot(r, Cinv_r - Cinv_M_A_Mt_Cinv_r) \
-                    - log_sqrt_det_B
-
-                numpyro.factor(f'logl_{i}', logl)
-
-                # update the prior mean and precision for the next detector
-                mu = a
-                Lambda_inv = A_inv
-                Lambda_inv_chol = A_inv_chol
+            likelihood_update = make_likelihood_update(dms, ls, strains)
+
+            mu_Lambda_inv_Lambda_inv_chol, logls = lax.scan(
+                likelihood_update,
+                (mu, Lambda_inv, Lambda_inv_chol),
+                jnp.arange(n_det),
+            )
+            mu, Lambda_inv, Lambda_inv_chol = mu_Lambda_inv_Lambda_inv_chol
+
+            # add likelihoods to potential
+            # TODO: check if can use numpyro.control_flow.scan instead
+            for i, logl in enumerate(logls):
+                numpyro.factor(f'logl_{i}', logl)
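
For reference, `lax.scan` threads a carry through the sequence and stacks the per-step outputs, so the final carry here is the fully updated (mu, Lambda_inv, Lambda_inv_chol) and `logls` collects one marginal log likelihood per detector. A toy sketch of the semantics (editorial illustration, not from the PR):

```python
import jax.numpy as jnp
from jax import lax

def step(carry, i):
    # stand-in for likelihood_update: fold element i into the carry
    carry = carry + i
    return carry, carry  # (updated carry, per-step output)

final, outputs = lax.scan(step, jnp.int32(0), jnp.arange(4))
# final == 6, outputs == [0, 1, 3, 6]
```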

if predictive:
# Generate the actual quadrature amplitudes by taking a draw
18 changes: 10 additions & 8 deletions ringdown/utils/swsh.py
@@ -1,8 +1,9 @@
__all__ = ['construct_sYlm', 'calc_YpYc']

import numpy as np
+import jax
import jax.numpy as jnp
-from scipy.special import factorial as fac
+from jax.scipy.special import factorial as fac


def binom_coeff(n, k):
@@ -21,8 +22,8 @@ def binom_coeff(n, k):

# binomial coefficient is zero if k>n, or generally if above formula
# returns an inf
-    num[denom == 0] = 0
-    denom[denom == 0] = 1
+    num = jnp.where(denom == 0, 0, num)
+    denom = jnp.where(denom == 0, 1, denom)
return num / denom
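
The switch to `jnp.where` is needed because in-place boolean-mask assignment only works on concrete NumPy arrays; under `jit` or `vmap` the operands are abstract tracers, which do not support item assignment. A standalone sketch of the same masking pattern (toy values, not from the PR):

```python
import jax
import jax.numpy as jnp

@jax.jit
def safe_ratio(num, denom):
    # functional replacement for num[denom == 0] = 0 and
    # denom[denom == 0] = 1, which would fail on traced arrays
    num = jnp.where(denom == 0, 0, num)
    denom = jnp.where(denom == 0, 1, denom)
    return num / denom

print(safe_ratio(jnp.array([2.0, 6.0]), jnp.array([0.0, 3.0])))  # [0. 2.]
```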


@@ -88,11 +89,12 @@ def ylm(cosi):

ylm = np.sqrt(1/0.159) * prefactor * sin_th_2(cosi)**(2*ell)

-        summands = [(-1)**r * binom_coeff(ell - s, r)
-                    * binom_coeff(ell + s, r + s - m)
-                    * cot_th_2(cosi)**(2*r + s - m)
-                    for r in rs]
-        ylm *= jnp.sum(jnp.array(summands), axis=0)
+        def get_summand(r):
+            return (-1)**r * binom_coeff(ell - s, r) * \
+                binom_coeff(ell + s, r + s - m) * \
+                cot_th_2(cosi)**(2*r + s - m)
+
+        ylm *= jnp.sum(jax.vmap(get_summand)(rs), axis=0)
return ylm
return ylm
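
Replacing the list comprehension with `jax.vmap` keeps the sum over r as a single batched evaluation inside the trace rather than len(rs) separate calls. A toy illustration with a stand-in summand (illustrative names, not from the module):

```python
import jax
import jax.numpy as jnp

def toy_summand(r):
    # stand-in for get_summand: any elementwise function of r
    return r ** 2

rs = jnp.arange(4)
total = jnp.sum(jax.vmap(toy_summand)(rs))  # 0 + 1 + 4 + 9 = 14
```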
