Conversation
|
SSMProblems.jl/SSMProblems documentation for PR #134 is available at: |
|
SSMProblems.jl/GeneralisedFilters documentation for PR #134 is available at: |
| StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | ||
| Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
| StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
| Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" |
There was a problem hiding this comment.
This is a really heavy dep. Have you considered putting it in an extension?
| Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
| StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
| Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" | ||
| Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" |
There was a problem hiding this comment.
I didn't look super duper carefully, but is there a reason Zygote and ForwardDiff need to be deps? Ideally ADTypes should be the only AD-related dependency.
|
I'm happy to review the Turing parts in more detail. This was the first thing that stood out, though, and I think it's better for me to review the code after it's been reorganised (unless there's a reason why that's not possible). |
|
Hey Penelope! Thanks for having a scan through. Yes, probably best waiting a few days until this is cleaned up a bit more, but would definitely appreciate your feedback then. As for the deps, that was just so I could get a prototype up quickly. The AD packages will be moved to test dependencies and Turing to a package extension. |
|
Sure, of course! For some reason I thought it was ready to go, sorry :) Feel free to ping whenever you have it more finalised. |
f34cfb1 to
1ce25c6
Compare
| function kf_loglikelihood(μ0, Σ0, As, bs, Qs, Hs, cs, Rs, ys) | ||
| T = length(ys) | ||
| state = MvNormal(μ0, Σ0) | ||
| ll = zero(eltype(μ0)) | ||
|
|
||
| for t in 1:T | ||
| state = kalman_predict(state, (As[t], bs[t], Qs[t])) | ||
| state, ll_inc, _ = _kalman_update_cached(state, Hs[t], cs[t], Rs[t], ys[t], nothing) | ||
| ll += ll_inc | ||
| end | ||
|
|
||
| return ll | ||
| end |
There was a problem hiding this comment.
It may not be the case that state generated from MvNormal(μ0, Σ0) is the same type as it would be from state = kalman_predict(state, (As[t], bs[t], Qs[t])). I've had issues with Zygote in particular when integrating patterns like this (see here). It may be as simple as separating the for loop to first allocate state and ll at time t=1 then looping over the rest since they guarantee type stability from that point forward.
| function trajectory_logdensity( | ||
| model::StateSpaceModel, trajectory, observations::AbstractVector | ||
| ) | ||
| T = length(observations) | ||
| ll = logpdf(SSMProblems.distribution(prior(model)), trajectory[0]) | ||
| for t in 1:T | ||
| ll += SSMProblems.logdensity(dyn(model), t, trajectory[t - 1], trajectory[t]) | ||
| ll += SSMProblems.logdensity(obs(model), t, trajectory[t], observations[t]) | ||
| end | ||
| return ll | ||
| end |
There was a problem hiding this comment.
Could be a similar issue here. When testing MLE, ForwardDiff would fail here in cases where the element type of prior(model) has no Duals to propagate through the for loop
d0bb210 to
100ac33
Compare
This PR provides support Particle Gibbs using Turing.jl, backed by GeneralisedFilter's CSMC implementation.
It works in both the regular and Rao-Blackwellised case, using a custom
rrulefor the gradient of the Kalman filter log-likelihood with respect to the model parameters.The final use is something like this
The PR is quite expansive but the individual commits are largely semantically meaningful. The implementation was broken up into a few phases:
A few caveats about the code:
That said...it does work!
Will close #67
Will close #48