Skip to content
Merged
50 changes: 28 additions & 22 deletions docs/deep_dives/l63_speedup_dirac_vs_enkf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,19 @@
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"import dynestyx as dsx\n",
"from dynestyx.dynamical_models import DynamicalModel, ContinuousTimeStateEvolution\n",
"from dynestyx.filters import FilterBasedMarginalLogLikelihood\n",
"from dynestyx.handlers import Condition, Discretizer\n",
"from dynestyx.ops import Context, Trajectory\n",
"from dynestyx.observations import DiracIdentityObservation, LinearGaussianObservation\n",
"from dynestyx.simulators import DiscreteTimeSimulator, SDESimulator"
"from dynestyx import (\n",
" DynamicalModel,\n",
" ContinuousTimeStateEvolution,\n",
" Context,\n",
" Condition,\n",
" Trajectory,\n",
" DiracIdentityObservation,\n",
" LinearGaussianObservation,\n",
" DiscreteTimeSimulator,\n",
" FilterBasedMarginalLogLikelihood,\n",
" SDESimulator,\n",
" Discretizer,\n",
")\n"
]
},
{
Expand Down Expand Up @@ -234,20 +241,27 @@
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 200/200 [00:00<00:00, 354.59it/s, 1 steps of size 3.10e-01. acc. prob=0.96]\n"
"sample: 100%|██████████| 200/200 [00:00<00:00, 564.98it/s, 1 steps of size 3.10e-01. acc. prob=0.96]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Discretizer + Dirac + NUTS: 2.08 seconds\n",
"Discretizer + Dirac + NUTS: 1.27 seconds\n",
"Posterior summary (rho):\n",
"Dirac:\n",
" mean = 27.76, std = 0.02\n",
"\n",
"True rho: 28.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
Expand Down Expand Up @@ -302,14 +316,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 200/200 [01:44<00:00, 1.92it/s, 7 steps of size 3.30e-01. acc. prob=0.97] "
"sample: 100%|██████████| 200/200 [01:14<00:00, 2.67it/s, 7 steps of size 3.30e-01. acc. prob=0.97] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EnKF + NUTS: 109.41 seconds\n"
"EnKF + NUTS: 78.80 seconds\n"
]
},
{
Expand Down Expand Up @@ -341,7 +355,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 9,
"id": "timing",
"metadata": {},
"outputs": [
Expand All @@ -352,9 +366,9 @@
"==================================================\n",
"TIMING\n",
"==================================================\n",
"EnKF + NUTS: 109.41 s\n",
"Discretizer + Dirac + NUTS: 2.08 s\n",
"Speedup: 52.7x\n",
"EnKF + NUTS: 78.80 s\n",
"Discretizer + Dirac + NUTS: 1.27 s\n",
"Speedup: 62.2x\n",
"==================================================\n",
"==================================================\n",
"POSTERIOR SUMMARY\n",
Expand Down Expand Up @@ -393,14 +407,6 @@
"\n",
"print(\"\\nNote that the speedup scales with T (number of observations), \\nso a 10x longer timeseries would experience ~10x further speedup using Discretizer + Dirac. \\nFeel free to experiment.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df5fe87e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
3 changes: 2 additions & 1 deletion dynestyx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
FilterBasedHMMMarginalLogLikelihood,
FilterBasedMarginalLogLikelihood,
)
from dynestyx.handlers import Condition, sample
from dynestyx.handlers import Condition, Discretizer, sample
from dynestyx.observations import DiracIdentityObservation, LinearGaussianObservation
from dynestyx.simulators import DiscreteTimeSimulator, ODESimulator, SDESimulator

Expand All @@ -22,6 +22,7 @@
"ContinuousTimeStateEvolution",
"DiscreteTimeStateEvolution",
"DynamicalModel",
"Discretizer",
"ObservationModel",
"Trajectory",
"FilterBasedHMMMarginalLogLikelihood",
Expand Down
53 changes: 40 additions & 13 deletions dynestyx/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
_DISCRETE_FILTER_TYPES,
_filter_discrete_time,
)
from dynestyx.utils import _get_controls
from dynestyx.utils import _get_controls, _should_add_site

type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM

Expand All @@ -38,14 +38,28 @@ class FilterBasedMarginalLogLikelihoodObjIntp(BaseCDDynamaxLogFactorAdder):

key: jax.Array | None = None
filter_type: str = "default"
output_fields = None
filter_kwargs: dict = dataclasses.field(default_factory=dict)

def __init__(self, filter_type="default", output_fields=None, **filter_kwargs):
record_filtered_states_mean: bool = True
record_filtered_states_cov: bool = True
record_filtered_states_cov_diag: bool = True
record_filtered_particles: bool = True
record_filtered_log_weights: bool = True
record_filtered_states_chol_cov: bool = True
record_max_elems: int = 100_000

def __init__(self, filter_type="default", **filter_kwargs):
super().__init__()
self.filter_type = filter_type
self.output_fields = output_fields
self.filter_kwargs = filter_kwargs if filter_kwargs is not None else {}
self.record_kwargs = {
"record_filtered_states_mean": self.record_filtered_states_mean,
"record_filtered_states_cov": self.record_filtered_states_cov,
"record_filtered_states_cov_diag": self.record_filtered_states_cov_diag,
"record_filtered_particles": self.record_filtered_particles,
"record_filtered_log_weights": self.record_filtered_log_weights,
"record_filtered_states_chol_cov": self.record_filtered_states_chol_cov,
"record_max_elems": self.record_max_elems,
}

def add_log_factors(
self,
Expand All @@ -67,7 +81,13 @@ def add_log_factors(
f"Invalid filter type: {self.filter_type}. Valid types: {_CONTINUOUS_FILTER_TYPES}"
)
_filter_continuous_time(
name, self.filter_type, dynamics, context, self.key, self.filter_kwargs
name,
self.filter_type,
dynamics,
context,
self.key,
self.filter_kwargs,
self.record_kwargs,
)
else:
if self.filter_type.lower() not in _DISCRETE_FILTER_TYPES:
Expand All @@ -81,13 +101,9 @@ def add_log_factors(
context,
self.key,
self.filter_kwargs,
self.record_kwargs,
)

# numpyro.deterministic(f"{name}_filtered_states_mean", filtered.filtered_means)
# numpyro.deterministic(f"{name}_filtered_states_cov", filtered.filtered_covariances)
# numpyro.deterministic(f"{name}_predicted_states_mean", filtered.predicted_means)
# numpyro.deterministic(f"{name}_predicted_states_cov", filtered.predicted_covariances)


@handles(FilterBasedMarginalLogLikelihoodObjIntp)
def FilterBasedMarginalLogLikelihood( # type: ignore[empty-body]
Expand All @@ -106,6 +122,7 @@ class FilterBasedHMMMarginalLogLikelihoodObjIntp(BaseCDDynamaxLogFactorAdder):

record_filtered: bool = False
record_log_filtered: bool = False
record_max_elems: int = 100_000

def add_log_factors(
self,
Expand Down Expand Up @@ -138,17 +155,27 @@ def add_log_factors(
)

numpyro.factor(
f"{name}_marginal_log_likelihood",
loglik,
)

# For use in predictive sampling
numpyro.deterministic(
f"{name}_marginal_loglik",
loglik,
)

if self.record_log_filtered:
if self.record_log_filtered and _should_add_site(
log_filt_seq.shape, self.record_max_elems
):
numpyro.deterministic(
f"{name}_log_filtered_states",
log_filt_seq, # (T, K)
)

if self.record_filtered:
if self.record_filtered and _should_add_site(
log_filt_seq.shape, self.record_max_elems
):
numpyro.deterministic(
f"{name}_filtered_states",
jnp.exp(log_filt_seq), # (T, K)
Expand Down
36 changes: 35 additions & 1 deletion dynestyx/inference/cd_dynamax/continuous_time_filters.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import numpyro
from cd_dynamax import ContDiscreteNonlinearGaussianSSM, ContDiscreteNonlinearSSM

from dynestyx.dynamical_models import Context, DynamicalModel
from dynestyx.inference.cd_dynamax.utils import dsx_to_cd_dynamax
from dynestyx.utils import _get_controls, _validate_control_dim
from dynestyx.utils import _get_controls, _should_add_site, _validate_control_dim

type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM

Expand All @@ -19,13 +20,17 @@ def _filter_continuous_time(
context: Context,
key: jax.Array | None = None,
filter_kwargs: dict | None = None,
record_kwargs: dict = {},
):
"""Continuous-time marginal likelihood via CD-Dynamax.

Args:
name: Name of the factor.
dynamics: Dynamical model to filter.
context: Context containing the observations and controls.
key: Random key for the filter.
filter_kwargs: Keyword arguments for the filter.
record_kwargs: Keyword arguments for recording the filtered states and their covariances.
"""

if filter_kwargs is None:
Expand Down Expand Up @@ -127,3 +132,32 @@ def _filter_continuous_time(

# Add the marginal log likelihood as a numpyro factor
numpyro.factor(f"{name}_marginal_log_likelihood", filtered.marginal_loglik)

# Add the marginal log likelihood as a deterministic site for easy access.
numpyro.deterministic(f"{name}_marginal_loglik", filtered.marginal_loglik)

# Optionally record the filtered states and their covariances as deterministic sites for easy access.
# Check dims before adding to protect against large arrays.
max_elems = record_kwargs.get("max_elems", 100_000)
means_shape = filtered.filtered_means.shape
cov_shape = filtered.filtered_covariances.shape

add_mean = record_kwargs.get(
"record_filtered_states_mean", False
) and _should_add_site(means_shape, max_elems)
add_cov = record_kwargs.get(
"record_filtered_states_cov", False
) and _should_add_site(cov_shape, max_elems)
add_cov_diag = record_kwargs.get(
"record_filtered_states_cov_diag", False
) and _should_add_site((cov_shape[0], cov_shape[1]), max_elems)

if add_mean:
numpyro.deterministic(f"{name}_filtered_states_mean", filtered.filtered_means)
if add_cov:
numpyro.deterministic(
f"{name}_filtered_states_cov", filtered.filtered_covariances
)
if add_cov_diag:
diag_cov = jnp.diagonal(filtered.filtered_covariances, axis1=1, axis2=2)
numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov)
Loading