Skip to content

Turing.jl Integration#134

Open
THargreaves wants to merge 19 commits intomainfrom
th/pgas-nuts
Open

Turing.jl Integration#134
THargreaves wants to merge 19 commits intomainfrom
th/pgas-nuts

Conversation

@THargreaves
Copy link
Collaborator

@THargreaves THargreaves commented Feb 17, 2026

This PR provides support Particle Gibbs using Turing.jl, backed by GeneralisedFilter's CSMC implementation.

It works in both the regular and Rao-Blackwellised case, using a custom rrule for the gradient of the Kalman filter log-likelihood with respect to the model parameters.

The final use is something like this

@model function my_model(ys)
    θ ~ [Priors]
    ssm = my_ssm_constructor_function(θ)
    x ~ SSMTrajectory(ssm, ys)
end

m = my_model(ys)
pg = ParticleGibbs(
    CSMC(RBPF(BF(N_particles), KF())), NUTS(0.8); adtype=ADTypes.AutoZygote()
)

chain = AbstractMCMC.sample(
    rng, m, pg, N_iter; n_adapts=N_adapts, progress=false, chain_type=MCMCChains.Chains
)

The PR is quite expansive but the individual commits are largely semantically meaningful. The implementation was broken up into a few phases:

  1. Defining CSMC samplers using the AbstractMCMC interface with no refreshment and ancestor sampling
  2. Adding backward simulation and sparse particle storage
  3. Defining the LogDensityProblems interface for regular and RB SSMs
  4. Implementing chain rules for the Kalman filter gradient
  5. Implementing a ParticleGibbs sampler for a manually defined SSM
  6. Extending this to a Turing @model defined SSM

A few caveats about the code:

  • I am not at all experienced with DynamicPPL/AbstractMCMC/LogDensityProblems so I might have implemented things weirdly. I'm more than happy to refactor if people point out issues
  • The gradients for the RBPF only work when using reverse-mode diff. See ForwardDiff support for analytical filter gradients #132
  • The implementation is painfully slow and likely contains a fair bit of type instability

That said...it does work!

Will close #67
Will close #48

@github-actions
Copy link
Contributor

SSMProblems.jl/SSMProblems documentation for PR #134 is available at:
https://TuringLang.github.io/SSMProblems.jl/SSMProblems/previews/PR134/

@github-actions
Copy link
Contributor

SSMProblems.jl/GeneralisedFilters documentation for PR #134 is available at:
https://TuringLang.github.io/SSMProblems.jl/GeneralisedFilters/previews/PR134/

StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a really heavy dep. Have you considered putting it in an extension?

Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't look super duper carefully, but is there a reason Zygote and ForwardDiff need to be deps? Ideally ADTypes should be the only AD-related dependency.

@penelopeysm
Copy link
Member

I'm happy to review the Turing parts in more detail. This was the first thing that stood out, though, and I think it's better for me to review the code after it's been reorganised (unless there's a reason why that's not possible).

@THargreaves
Copy link
Collaborator Author

Hey Penelope! Thanks for having a scan through. Yes, probably best waiting a few days until this is cleaned up a bit more, but would definitely appreciate your feedback then.

As for the deps, that was just so I could get a prototype up quickly. The AD packages will be moved to test dependencies and Turing to a package extension.

@penelopeysm
Copy link
Member

Sure, of course! For some reason I thought it was ready to go, sorry :) Feel free to ping whenever you have it more finalised.

Comment on lines +148 to +160
function kf_loglikelihood(μ0, Σ0, As, bs, Qs, Hs, cs, Rs, ys)
T = length(ys)
state = MvNormal(μ0, Σ0)
ll = zero(eltype(μ0))

for t in 1:T
state = kalman_predict(state, (As[t], bs[t], Qs[t]))
state, ll_inc, _ = _kalman_update_cached(state, Hs[t], cs[t], Rs[t], ys[t], nothing)
ll += ll_inc
end

return ll
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may not be the case that state generated from MvNormal(μ0, Σ0) is the same type as it would be from state = kalman_predict(state, (As[t], bs[t], Qs[t])). I've had issues with Zygote in particular when integrating patterns like this (see here). It may be as simple as separating the for loop to first allocate state and ll at time t=1 then looping over the rest since they guarantee type stability from that point forward.

Comment on lines +18 to +28
function trajectory_logdensity(
model::StateSpaceModel, trajectory, observations::AbstractVector
)
T = length(observations)
ll = logpdf(SSMProblems.distribution(prior(model)), trajectory[0])
for t in 1:T
ll += SSMProblems.logdensity(dyn(model), t, trajectory[t - 1], trajectory[t])
ll += SSMProblems.logdensity(obs(model), t, trajectory[t], observations[t])
end
return ll
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be a similar issue here. When testing MLE, ForwardDiff would fail here in cases where the element type of prior(model) has no Duals to propagate through the for loop

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AbstractMCMC/Turing integration Wrote unit tests for Particle Gibbs

3 participants