Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import block_diag as jsp_block_diag

from tinygp.helpers import JAXArray
from tinygp.kernels.base import Kernel
from tinygp.solvers.quasisep.block import Block
from tinygp.solvers.quasisep.block import Block, ensure_dense
from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM
from tinygp.solvers.quasisep.general import GeneralQSM

Expand Down Expand Up @@ -220,20 +221,39 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:


class Sum(Quasisep):
"""A helper to represent the sum of two quasiseparable kernels"""
"""A helper to represent the sum of two quasiseparable kernels

Args:
kernel1: The first kernel.
kernel2: The second kernel.
use_block: If ``True`` (default), use :class:`Block` diagonal matrices
for the transition matrices, design matrices, and stationary
covariance. If ``False``, use dense ``block_diag`` representations
instead, which avoids compatibility issues with some operations
(e.g. banded noise, product kernels) at a small performance cost
for the state-space matrices.
"""

kernel1: Quasisep
kernel2: Quasisep
use_block: bool = eqx.field(static=True, default=True)

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
    """Map coordinates to a sortable value via the first kernel.

    Both summed kernels are assumed to share the same coordinate
    convention, so delegating to ``kernel1`` is sufficient.
    """
    reference = self.kernel1
    return reference.coord_to_sortable(X)

def _block_or_dense(self, m1: JAXArray, m2: JAXArray) -> JAXArray:
    """Combine two square matrices block-diagonally.

    Returns a structured :class:`Block` when ``use_block`` is enabled,
    and a dense ``jax.scipy.linalg.block_diag`` array otherwise.
    """
    if not self.use_block:
        return jsp_block_diag(m1, m2)
    return Block(m1, m2)

def design_matrix(self) -> JAXArray:
    """Block-diagonal combination of the two component design matrices.

    Uses :meth:`_block_or_dense`, so the result is a :class:`Block` when
    ``use_block`` is set and a dense array otherwise.
    """
    # NOTE: the pasted diff retained both the pre-change `return Block(...)`
    # and its replacement; only the merged (post-change) statement is kept.
    return self._block_or_dense(
        self.kernel1.design_matrix(), self.kernel2.design_matrix()
    )

def stationary_covariance(self) -> JAXArray:
    """Block-diagonal combination of the component stationary covariances.

    Uses :meth:`_block_or_dense`, so the result is a :class:`Block` when
    ``use_block`` is set and a dense array otherwise.
    """
    # NOTE: the pasted diff retained both the pre-change `return Block(`
    # and its replacement; only the merged (post-change) statement is kept.
    return self._block_or_dense(
        self.kernel1.stationary_covariance(),
        self.kernel2.stationary_covariance(),
    )
Expand All @@ -247,7 +267,7 @@ def observation_model(self, X: JAXArray) -> JAXArray:
)

def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
    """Block-diagonal combination of the component transition matrices.

    Uses :meth:`_block_or_dense`, so the result is a :class:`Block` when
    ``use_block`` is set and a dense array otherwise.
    """
    # NOTE: the pasted diff retained both the pre-change `return Block(`
    # and its replacement; only the merged (post-change) statement is kept.
    return self._block_or_dense(
        self.kernel1.transition_matrix(X1, X2),
        self.kernel2.transition_matrix(X1, X2),
    )
Expand Down Expand Up @@ -632,6 +652,8 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:


def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray:
a1 = ensure_dense(a1)
a2 = ensure_dense(a2)
i, j = np.meshgrid(np.arange(a1.shape[0]), np.arange(a2.shape[0]))
i = i.flatten()
j = j.flatten()
Expand Down
7 changes: 7 additions & 0 deletions src/tinygp/solvers/quasisep/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from tinygp.helpers import JAXArray


def ensure_dense(x: Any) -> Any:
    """Densify *x* when it is a :class:`Block`; pass anything else through."""
    return x.to_dense() if isinstance(x, Block) else x


class Block(eqx.Module):
blocks: tuple[Any, ...]
__array_priority__ = 1999
Expand Down
9 changes: 7 additions & 2 deletions src/tinygp/solvers/quasisep/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax.scipy.linalg import block_diag

from tinygp.helpers import JAXArray
from tinygp.solvers.quasisep.block import ensure_dense


def handle_matvec_shapes(
Expand Down Expand Up @@ -213,20 +214,24 @@ def impl(
return StrictLowerTriQSM(
p=jnp.concatenate((p1, p2)),
q=jnp.concatenate((q1, q2)),
a=block_diag(a1, a2),
a=block_diag(ensure_dense(a1), ensure_dense(a2)),
)

return impl(self, other)

def self_mul(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM:
    """The elementwise product of two :class:`StrictLowerTriQSM` matrices"""
    # The pasted diff kept both the old and new `a=` keyword arguments,
    # producing a duplicate-keyword SyntaxError; only the post-change
    # (ensure_dense-based) version is kept.
    #
    # vmap is needed because a batched Block has 3D block arrays that
    # block_diag (used by to_dense) cannot handle without unbatching.
    self_a = jax.vmap(ensure_dense)(self.a)
    other_a = jax.vmap(ensure_dense)(other.a)
    i, j = np.meshgrid(np.arange(self.p.shape[1]), np.arange(other.p.shape[1]))
    i = i.flatten()
    j = j.flatten()
    return StrictLowerTriQSM(
        p=self.p[:, i] * other.p[:, j],
        q=self.q[:, i] * other.q[:, j],
        a=self_a[:, i[:, None], i[None, :]] * other_a[:, j[:, None], j[None, :]],
    )

def __neg__(self) -> StrictLowerTriQSM:
Expand Down
29 changes: 15 additions & 14 deletions src/tinygp/solvers/quasisep/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax.numpy as jnp

from tinygp.helpers import JAXArray
from tinygp.solvers.quasisep.block import ensure_dense
from tinygp.solvers.quasisep.core import (
QSM,
DiagQSM,
Expand Down Expand Up @@ -145,15 +146,15 @@ def impl(
u += [upper_b.p] if upper_b is not None else []

if lower_a is not None and lower_b is not None:
la_a = ensure_dense(lower_a.a)
lb_a = ensure_dense(lower_b.a)
ell = jnp.concatenate(
(
jnp.concatenate(
(lower_a.a, jnp.outer(lower_a.q, lower_b.p)), axis=-1
),
jnp.concatenate((la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1),
jnp.concatenate(
(
jnp.zeros((lower_b.a.shape[0], lower_a.a.shape[0])),
lower_b.a,
jnp.zeros((lb_a.shape[0], la_a.shape[0])),
lb_a,
),
axis=-1,
),
Expand All @@ -162,33 +163,33 @@ def impl(
)
else:
ell = (
lower_a.a
ensure_dense(lower_a.a)
if lower_a is not None
else lower_b.a if lower_b is not None else None
else ensure_dense(lower_b.a) if lower_b is not None else None
)

if upper_a is not None and upper_b is not None:
ua_a = ensure_dense(upper_a.a)
ub_a = ensure_dense(upper_b.a)
delta = jnp.concatenate(
(
jnp.concatenate(
(
upper_a.a,
jnp.zeros((upper_a.a.shape[0], upper_b.a.shape[0])),
ua_a,
jnp.zeros((ua_a.shape[0], ub_a.shape[0])),
),
axis=-1,
),
jnp.concatenate(
(jnp.outer(upper_b.q, upper_a.p), upper_b.a), axis=-1
),
jnp.concatenate((jnp.outer(upper_b.q, upper_a.p), ub_a), axis=-1),
),
axis=0,
)

else:
delta = (
upper_a.a
ensure_dense(upper_a.a)
if upper_a is not None
else upper_b.a if upper_b is not None else None
else ensure_dense(upper_b.a) if upper_b is not None else None
)

return (
Expand Down
39 changes: 39 additions & 0 deletions tests/test_kernels/test_quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from tinygp import GaussianProcess
from tinygp.kernels import quasisep
from tinygp.noise import Banded
from tinygp.test_utils import assert_allclose


Expand Down Expand Up @@ -157,3 +158,41 @@ def test_carma_quads():
assert_allclose(carma31.arroots, carma31_quads.arroots)
assert_allclose(carma31.acf, carma31_quads.acf)
assert_allclose(carma31.obsmodel, carma31_quads.obsmodel)


def test_sum_kernel_with_banded_noise(data):
    """A Sum kernel must compose with Banded observational noise."""
    x, y, _ = data
    n = len(x)
    kernel = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
    noise = Banded(diag=0.1 * jnp.ones(n), off_diags=0.01 * jnp.ones((n, 1)))
    gp = GaussianProcess(kernel, x, noise=noise)
    assert jnp.isfinite(gp.log_probability(y))
    lp, _cond = gp.condition(y)
    assert jnp.isfinite(lp)


def test_product_of_sum_kernel(data):
    """A (sum * kernel) product must evaluate and match its dense form."""
    x, y, _ = data
    summed = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
    kernel = summed * quasisep.Exp(1.0)
    gp = GaussianProcess(kernel, x, diag=jnp.ones(len(x)))
    assert jnp.isfinite(gp.log_probability(y))
    assert_allclose(kernel.to_symm_qsm(x).to_dense(), kernel(x, x))


def test_sum_times_sum_kernel(data):
    """The product of two Sum kernels must yield a finite log probability."""
    x, y, _ = data
    left = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
    right = quasisep.Exp(0.5) + quasisep.Matern32(1.0)
    gp = GaussianProcess(left * right, x, diag=jnp.ones(len(x)))
    assert jnp.isfinite(gp.log_probability(y))


def test_sum_kernel_use_block_false(data):
    """Block and dense Sum representations must give the same likelihood."""
    x, y, _ = data
    diag = 0.1 * jnp.ones(len(x))
    kernel_block = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
    kernel_dense = quasisep.Sum(
        quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False
    )
    gp_block = GaussianProcess(kernel_block, x, diag=diag)
    gp_dense = GaussianProcess(kernel_dense, x, diag=diag)
    assert_allclose(gp_block.log_probability(y), gp_dense.log_probability(y))