Skip to content

Plate-Based Primitive for Hierarchical Data#173

Merged
mattlevine22 merged 58 commits intomainfrom
dw-hierarchical-trajectories
Apr 9, 2026
Merged

Plate-Based Primitive for Hierarchical Data#173
mattlevine22 merged 58 commits intomainfrom
dw-hierarchical-trajectories

Conversation

@DanWaxman
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman commented Mar 23, 2026

Add hierarchical trajectory support across sampling, filtering, and simulation, allowing for multiple-trajectory and mixed effect modelling via a new plate primitive.

Example:

with dsx.plate("groups", G):
    beta = numpyro.sample("beta", dist.Normal(0.0, 0.3))

    with dsx.plate("trajectories", M):
        alpha = numpyro.sample("alpha", dist.Uniform(-0.7, 0.7))

        A = jnp.repeat(A_base[None], M, axis=0).at[:, 0, 0].set(alpha)
        dynamics = LTI_discrete(A=A, Q=Q, H=H, R=R)

        dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values)

Nested plates like groups -> trajectories now work with plated parameters and inference

High-level changes

  • A new sample interpretation, dsx.plate
    • This inherits numpyro plating semantics, attaches metadata, and passes forward.
  • Made models and validators plate-aware, including automatic bm_dim inference in plated continuous-time models
  • Extended filtering to run over plated batches for:
    • continuous-time CD-dynamax-backed paths
    • discrete-time CD-dynamax-backed paths
    • HMM
    • cuthbert-backed paths
  • Extended simulators to slice per-plate inputs, run each member, and stack results back into plated outputs
  • Added a new hierarchical inference tutorial and updated API/docs navigation
  • Added smoke and science tests for hierarchical simulation and inference

Testing

  • Added coverage for plated bm_dim inference
  • Added hierarchical smoke tests for simulators, inference, and SVI
  • Added science tests for hierarchical HMM, LTI, nonlinear EKF, ODE, and simulator inference

@DanWaxman DanWaxman marked this pull request as ready for review April 7, 2026 12:54
@DanWaxman DanWaxman requested a review from mattlevine22 April 7, 2026 12:55
Copy link
Copy Markdown
Collaborator

@mattlevine22 mattlevine22 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One small check on the pyproject.

Other (bigger) ask is to post on slack the results from test_science especially as you've added new science tests

Comment thread pyproject.toml
@DanWaxman
Copy link
Copy Markdown
Collaborator Author

I ran all the new tests and had Codex generate a small report. I left it on the Slack, but in the spirit of open development, it can also be found here: https://drive.google.com/file/d/1NVFHmKwQMHz-N8g3lWdQjnm5nCKXzCOE/view?usp=sharing

@mattlevine22
Copy link
Copy Markdown
Collaborator

All the plots for the science tests look good.

I think we should add science test coverage for ODESimulator and Discretizer

Can we add 2 more tests?

  • ODE (maybe just multiple trajectories and no parameter hierarchy?)
  • Can re-use a continuous-time example but apply Discretizer.

@DanWaxman
Copy link
Copy Markdown
Collaborator Author

Discretizer test was added and looks good.
image

There was actually already an ODE one that I forgot to include in the "report," but I edited it to get a bit better posteriors/run faster. The results don't look amazing, but this example may just be something the ODEs struggle with.
image

@DanWaxman DanWaxman requested a review from mattlevine22 April 8, 2026 16:56
@mattlevine22
Copy link
Copy Markdown
Collaborator

Hmmm, I think it is worth cooking up an ODE example that makes it more clear that it is "working". From the current plot + eyeball norm, I'm not seeing a significant adaption towards the truth in each trajectory....this could just be challenging inference OR it could be that data/information is incorrectly routed.

@DanWaxman
Copy link
Copy Markdown
Collaborator Author

Fair enough! I think that coupled example is just not super identifiable for ODEs. Here's an alternative damped oscillator model:

image

Copy link
Copy Markdown
Collaborator

@mattlevine22 mattlevine22 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome work, thanks @DanWaxman! We'll now have support for hierarchical and multi-trajectory inference!

@mattlevine22 mattlevine22 merged commit 762dcb6 into main Apr 9, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support (hierarchical) inference over multiple trajectories

2 participants