Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
using BangBang: setindex!!
using MCMCChains: MCMCChains

function getindex_varname(
Expand Down Expand Up @@ -82,7 +83,7 @@ end
"""
AbstractMCMC.to_samples(
::Type{DynamicPPL.ParamsWithStats},
chain::MCMCChains.Chains
chain::MCMCChains.Chains,
)

Convert an `MCMCChains.Chains` object to an array of `DynamicPPL.ParamsWithStats`.
Expand All @@ -95,11 +96,11 @@ function AbstractMCMC.to_samples(
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
# Get parameters
params_matrix = map(idxs) do (sample_idx, chain_idx)
d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
vnt = DynamicPPL.VarNamedTuple()
for vn in get_varnames(chain)
d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
vnt = setindex!!(vnt, getindex_varname(chain, sample_idx, vn, chain_idx), vn)
end
d
vnt
end
# Statistics
stats_matrix = if :internals in MCMCChains.sections(chain)
Expand Down Expand Up @@ -164,8 +165,8 @@ end
fallback=nothing,
)

Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`,
returning an matrix of `(retval, updated_at)` tuples.
Re-evaluate `model` for each sample in `chain` using the accumulators provided in `accs`,
returning a matrix of `(retval, updated_at)` tuples.

This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
initialisation strategy when re-evaluating the model. For many usecases the fallback should
Expand Down
2 changes: 1 addition & 1 deletion src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ A struct which contains parameter values extracted from a `VarInfo`, along with
statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are
optional.
"""
struct ParamsWithStats{P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple}
struct ParamsWithStats{P<:VarNamedTuple,S<:NamedTuple}
params::P
stats::S
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,14 @@ end

function generate_assign(left, right)
# A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for
# ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator.
# ValuesAsInModel then in addition we push!! the pair of `x` and `y` to the accumulator.
@gensym acc right_val vn
return quote
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
$acc -> push!!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
)
end
$left = $right_val
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ InitFromParams(params) = InitFromParams(params, InitFromPrior())

function init(
rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P}
) where {P<:Union{AbstractDict{<:VarName},NamedTuple}}
) where {P<:Union{AbstractDict{<:VarName},NamedTuple,VarNamedTuple}}
# TODO(penelopeysm): It would be nice to do a check to make sure that all
# of the parameters in `p.params` were actually used, and either warn or
# error if they aren't. This is actually quite non-trivial though because
Expand Down
32 changes: 30 additions & 2 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,26 @@
#
# Some additionally contain an implementation of `rand_prior_true`.

"""
varnames(model::Model)

Return the VarNames defined in `model`, as a Vector.
"""
function varnames end

# TODO(mhauru) The fact that the below function exists is a sign that we are inconsistent in
# how we handle IndexLenses. This should hopefully be resolved once we consistently use
# VarNamedTuple rather than dictionaries everywhere.
"""
varnames_split(model::Model)

Return the VarNames in `model`, with any ranges or colons split into individual indices.

The default implementation is to just return `varnames(model)`. If something else is needed,
this should be defined separately.
"""
varnames_split(model::Model) = varnames(model)

"""
demo_dynamic_constraint()

Expand Down Expand Up @@ -77,6 +97,9 @@ end
function varnames(model::Model{typeof(demo_one_variable_multiple_constraints)})
return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4:5])]
end
function varnames_split(model::Model{typeof(demo_one_variable_multiple_constraints)})
return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4]), @varname(x[5])]
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_one_variable_multiple_constraints)}, x
)
Expand Down Expand Up @@ -624,8 +647,13 @@ function varnames(::Model{typeof(demo_nested_colons)})
AbstractPPL.ConcretizedSlice(Base.Slice(Base.OneTo(2))),
]
),
# @varname(s.params[1].subparams[1,1,1]),
# @varname(s.params[1].subparams[1,1,2]),
@varname(m),
]
end
function varnames_split(::Model{typeof(demo_nested_colons)})
return [
@varname(s.params[1].subparams[1, 1, 1]),
@varname(s.params[1].subparams[1, 1, 2]),
@varname(m),
]
end
Expand Down
20 changes: 14 additions & 6 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ wants to extract the realization of a model in a constrained space.
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelAccumulator <: AbstractAccumulator
struct ValuesAsInModelAccumulator{VNT<:VarNamedTuple} <: AbstractAccumulator
"values that are extracted from the model"
values::OrderedDict{<:VarName}
values::VNT
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
end
function ValuesAsInModelAccumulator(include_colon_eq)
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
return ValuesAsInModelAccumulator(VarNamedTuple(), include_colon_eq)
end

function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator)
Expand All @@ -30,6 +30,9 @@ end

accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

# TODO(mhauru) We could start using reset!!, which could call empty!! on the VarNamedTuple.
# This would create VarNamedTuples that share memory with the original one, saving
# allocations but also making them not capable of taking in any arbitrary VarName.
function _zero(acc::ValuesAsInModelAccumulator)
return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
end
Expand All @@ -45,8 +48,11 @@ function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumula
)
end

function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
setindex!(acc.values, deepcopy(val), vn)
function BangBang.push!!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
# TODO(mhauru) The deepcopy here is quite unfortunate. It is needed so that the model
# body can go mutating the object without that reactively affecting the value in the
# accumulator, which should be as it was at `~` time. Could there be a way around this?
Accessors.@reset acc.values = setindex!!(acc.values, deepcopy(val), vn)
return acc
end

Expand All @@ -56,7 +62,7 @@ function is_extracting_values(vi::AbstractVarInfo)
end

function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right)
return push!(acc, vn, val)
return push!!(acc, vn, val)
end

accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc
Expand All @@ -75,6 +81,8 @@ working in unconstrained space.
Hence this method is a "safe" way of obtaining realizations in constrained
space at the cost of additional model evaluations.

Returns a `VarNamedTuple`.

# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
Expand Down
Loading