Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
427 changes: 427 additions & 0 deletions docs/deep_dives/l63_speedup_dirac_vs_enkf.ipynb

Large diffs are not rendered by default.

68 changes: 52 additions & 16 deletions dynestyx/discretizers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Discretization schemes for converting continuous-time state evolution to discrete-time."""

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import vmap

from dynestyx.dynamical_models import (
ContinuousTimeStateEvolution,
Expand All @@ -16,19 +14,57 @@ def __init__(self, cte: ContinuousTimeStateEvolution):
self.cte = cte

def __call__(self, x, u, t_now, t_next):
delta_t = t_next - t_now
drift = self.cte.drift(x, u, t_now)
L_fn = getattr(self.cte, "diffusion_coefficient", None)
if L_fn is None:
raise AttributeError(
"ContinuousTimeStateEvolution must define diffusion_coefficient."
)
L = L_fn(x, u, t_now) if callable(L_fn) else L_fn
Q_fn = getattr(self.cte, "diffusion_covariance", jnp.eye(x.shape[-1]))
Q = Q_fn(x, u, t_now) if callable(Q_fn) else Q_fn
mean = x + drift * delta_t
cov = (L @ Q @ L.T) * delta_t
return dist.MultivariateNormal(loc=mean, covariance_matrix=cov)
"""
Discretize continuous-time state evolution via Euler-Maruyama. (CTSE) -> DTSE.

We step from t_now to t_next for each timepoint provided (optionally just 1 timepoint provided).
The main use case of providing multiple timepoints is when paired with DiracDeltaObservation that
allows temporal independence between observations, which allows us to step through all timepoints at once (creating big speedups).

Args:
x: (dim_state,) or (dim_state, num_timepoints)
u: (dim_control,) or (dim_control, num_timepoints)
t_now: (1,) or (num_timepoints,)
t_next: (1,) or (num_timepoints,)

Returns:
dist: MultivariateNormal distribution
- loc: (dim_state, num_timepoints) or (dim_state)
- covariance_matrix: (dim_state, dim_state, num_timepoints) or (dim_state, dim_state)
"""

squeezed = False
if x.ndim == 1:
squeezed = True
x = x[:, None] # (dim_state, 1) state
if u is not None:
if u.ndim == 1:
u = u[:, None] # (dim_control, 1) control
if t_now.ndim == 0:
t_now = t_now[None] # (1,) timepoint
if t_next.ndim == 0:
t_next = t_next[None] # (1,) timepoint

def _step(_x, _u, _t_now, _t_next):
_dt = _t_next - _t_now
drift = self.cte.drift(_x, _u, _t_now)
x_pred_mean = _x + drift * _dt
L = self.cte.diffusion_coefficient(_x, _u, _t_now)
Q = self.cte.diffusion_covariance(_x, _u, _t_now)
x_pred_cov = L @ Q @ L.T * _dt
return x_pred_mean, x_pred_cov

if u is None:
loc, cov = vmap(_step, in_axes=(1, None, 0, 0))(x, None, t_now, t_next)
else:
loc, cov = vmap(_step, in_axes=(1, 1, 0, 0))(x, u, t_now, t_next)

# If we lifted from unbatched, return unbatched dist shapes
if squeezed:
loc = loc[0]
cov = cov[0]

return dist.MultivariateNormal(loc=loc, covariance_matrix=cov)


def euler_maruyama(cte: ContinuousTimeStateEvolution) -> DiscreteTimeStateEvolution:
Expand Down
34 changes: 28 additions & 6 deletions dynestyx/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,20 +212,42 @@ def simulate(
numpyro.sample("x_0", dynamics.initial_condition, obs=obs_values[0])
numpyro.deterministic("y_0", obs_values[0])

x_prev = obs_values[:-1]
x_next = obs_values[1:]
u_prev = ctrl_values[:-1] if ctrl_values is not None else None
# Ensure (T-1, state_dim) so swapaxes to (state_dim, T-1) is valid (state_dim=1 => 1D otherwise).
if obs_values.ndim == 1:
x_prev = obs_values[:-1][:, None]
x_next = obs_values[1:][:, None]
else:
x_prev = obs_values[:-1]
x_next = obs_values[1:]
if ctrl_values is not None:
if ctrl_values.ndim == 1:
u_prev = ctrl_values[:-1][:, None]
else:
u_prev = ctrl_values[:-1]
else:
u_prev = None
Comment thread
DanWaxman marked this conversation as resolved.
t_now = obs_times[:-1]
t_next = obs_times[1:]

# Pass state (and controls) with batch as last axis so drift can use
# naive indexing (x[0], x[1], ...) and discretizer broadcasts correctly.
x_prev_batch_last = jnp.swapaxes(x_prev, 0, 1)
x_next_batch_last = jnp.swapaxes(x_next, 0, 1)
u_prev_batch_last = (
jnp.swapaxes(u_prev, 0, 1) if u_prev is not None else None
)

with numpyro.plate("time", T - 1):
trans = dynamics.state_evolution(
x_prev,
u_prev,
x_prev_batch_last,
u_prev_batch_last,
t_now,
t_next, # type: ignore
)
numpyro.sample("x_next", trans, obs=x_next) # type: ignore
# obs shape must match trans.batch_shape + trans.event_shape: use
# time-first (T-1, state_dim) for e.g. discretizer; batch-last (state_dim, T-1) for scalar.
obs_next = x_next_batch_last if dynamics.state_dim == 1 else x_next
numpyro.sample("x_next", trans, obs=obs_next) # type: ignore

return {
"times": obs_times,
Expand Down
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ nav:
- HMM Inference: tutorials/hmm_inference.ipynb
- SDE with Non-Gaussian Observations: tutorials/sde_non_gaussian_observations.ipynb
- ODE Inference: tutorials/ode_inference.ipynb
- Deep Dives:
- Speedups with full-state low-noise observations: deep_dives/l63_speedup_dirac_vs_enkf.ipynb

theme:
name: "material"
Expand Down
9 changes: 0 additions & 9 deletions tests/test_mcmc_smokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

import jax.random as jr
import pytest
from numpyro.infer import MCMC, NUTS, BarkerMH

from tests.fixtures import (
Expand Down Expand Up @@ -72,14 +71,6 @@ def test_discrete_time_l63_auto_mcmc_smoke(
assert "rho" in posterior_samples


# This test is expected to fail currently due to broadcasting issues in the discretizer/simulator interaction.
@pytest.mark.skipif(
True,
reason=(
"Expected to fail currently: exposes broadcasting interaction between "
"Discretizer, DiracIdentityObservation, and DiscreteTimeSimulator."
),
)
def test_discrete_time_l63_auto_dirac_obs_mcmc_smoke(
data_conditioned_discrete_time_l63_auto_dirac_obs, # noqa: F811
) -> None:
Expand Down